diff --git a/Cargo.lock b/Cargo.lock
index b4e7319c6dc..b0ef321a87d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2347,6 +2347,7 @@ dependencies = [
  "futures",
  "libp2p",
  "prost",
+ "rand 0.8.5",
  "reed-solomon-simd",
  "rstest",
  "sha2 0.10.9",
diff --git a/crates/apollo_propeller/Cargo.toml b/crates/apollo_propeller/Cargo.toml
index eaee8810401..c80648519cd 100644
--- a/crates/apollo_propeller/Cargo.toml
+++ b/crates/apollo_propeller/Cargo.toml
@@ -16,6 +16,7 @@ asynchronous-codec.workspace = true
 futures.workspace = true
 libp2p.workspace = true
 prost.workspace = true
+rand.workspace = true
 reed-solomon-simd.workspace = true
 sha2.workspace = true
 thiserror.workspace = true
diff --git a/crates/apollo_propeller/src/message_processor.rs b/crates/apollo_propeller/src/message_processor.rs
index 7a9a5e761e0..b8ba2084052 100644
--- a/crates/apollo_propeller/src/message_processor.rs
+++ b/crates/apollo_propeller/src/message_processor.rs
@@ -3,22 +3,116 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use libp2p::identity::{PeerId, PublicKey};
+use rand::seq::SliceRandom;
 use tokio::sync::mpsc;
-use tokio::time::sleep_until;
-use tracing::{debug, trace};
+use tracing::{debug, error, trace};
 
+use crate::sharding::reconstruct_message_from_shards;
 use crate::tree::PropellerScheduleManager;
-use crate::types::{Channel, Event, MessageRoot};
+use crate::types::{Channel, Event, MessageRoot, ReconstructionError, ShardValidationError};
 use crate::unit::PropellerUnit;
-use crate::ShardIndex;
+use crate::unit_validator::UnitValidator;
+use crate::{MerkleProof, ShardIndex};
 
 pub type UnitToValidate = (PeerId, PropellerUnit);
+type ValidationResult = (Result<(), ShardValidationError>, UnitValidator, PropellerUnit);
+type ReconstructionResult = Result<ReconstructionOutput, ReconstructionError>;
 
 #[derive(Debug)]
 pub enum EventStateManagerToEngine {
     BehaviourEvent(Event),
     Finalized { channel: Channel, publisher: PeerId, message_root: MessageRoot },
-    BroadcastUnit { unit: PropellerUnit, peers: Vec<PeerId> },
+    SendUnitToPeers { unit: PropellerUnit, peers: Vec<PeerId> },
+}
+
+#[derive(Debug)]
+struct ReconstructionOutput {
+    message: Vec<u8>,
+    my_shard: Vec<u8>,
+    my_shard_proof: MerkleProof,
+}
+
+/// Tracks reconstruction progress for a single message.
+enum ReconstructionState {
+    PreConstruction {
+        received_shards: Vec<PropellerUnit>,
+        did_broadcast_my_shard: bool,
+        signature: Option<Vec<u8>>,
+    },
+    /// Message was reconstructed but not yet delivered to the application. We keep collecting
+    /// shards until the access threshold is reached, then emit the message.
+    // No need to track the unit indices after reconstruction (unit duplication already validated).
+    PostConstruction { reconstructed_message: Option<Vec<u8>>, num_received_shards: usize },
+}
+
+impl ReconstructionState {
+    fn new() -> Self {
+        Self::PreConstruction {
+            received_shards: Vec::new(),
+            did_broadcast_my_shard: false,
+            signature: None,
+        }
+    }
+
+    fn is_reconstructed(&self) -> bool {
+        matches!(self, Self::PostConstruction { .. })
+    }
+
+    fn did_broadcast_my_shard(&self) -> bool {
+        match self {
+            Self::PreConstruction { did_broadcast_my_shard, .. } => *did_broadcast_my_shard,
+            Self::PostConstruction { .. } => true,
+        }
+    }
+
+    fn record_shard(&mut self, is_my_shard: bool) {
+        match self {
+            Self::PreConstruction { did_broadcast_my_shard, .. } => {
+                if is_my_shard {
+                    *did_broadcast_my_shard = true;
+                }
+            }
+            Self::PostConstruction { num_received_shards, .. } => {
+                if !is_my_shard {
+                    *num_received_shards += 1;
+                }
+            }
+        }
+    }
+
+    fn capture_signature(&mut self, unit: &PropellerUnit) {
+        if let Self::PreConstruction { signature, .. } = self {
+            if signature.is_none() {
+                *signature = Some(unit.signature().to_vec());
+            }
+            // The signature was already validated to be the same for all units.
+        }
+    }
+
+    fn take_shards(&mut self) -> Vec<PropellerUnit> {
+        match self {
+            Self::PreConstruction { received_shards, .. } => std::mem::take(received_shards),
+            Self::PostConstruction { .. } => Vec::new(),
+        }
+    }
+
+    fn push_shard(&mut self, unit: PropellerUnit) {
+        if let Self::PreConstruction { received_shards, .. } = self {
+            received_shards.push(unit);
+        }
+    }
+
+    fn received_shard_count(&self) -> usize {
+        match self {
+            Self::PreConstruction { received_shards, .. } => received_shards.len(),
+            Self::PostConstruction { .. } => 0,
+        }
+    }
+
+    fn transition_to_post(&mut self, message: Vec<u8>, num_received_shards: usize) {
+        *self =
+            Self::PostConstruction { reconstructed_message: Some(message), num_received_shards };
+    }
 }
 
 /// Message processor that handles validation and state management for a single message.
@@ -47,18 +141,10 @@ impl MessageProcessor {
             self.channel, self.publisher, self.message_root
         );
 
-        // Local state variables
-        let deadline = tokio::time::Instant::now() + self.timeout;
+        let timed_out = tokio::time::timeout(self.timeout, self.process_units()).await.is_err();
 
-        // TODO(AndrewL): remove this
-        #[allow(clippy::never_loop)]
-        loop {
-            tokio::select! {
-                _ = sleep_until(deadline) => {
-                    let _ = self.emit_timeout_and_finalize().await;
-                    break;
-                }
-            }
+        if timed_out {
+            self.emit_timeout_and_finalize();
         }
 
         debug!(
@@ -67,25 +153,227 @@ impl MessageProcessor {
         );
     }
 
-    async fn emit_timeout_and_finalize(&mut self) -> ControlFlow<()> {
+    async fn process_units(&mut self) {
+        let mut validator = UnitValidator::new(
+            self.channel,
+            self.publisher,
+            self.publisher_public_key.clone(),
+            self.message_root,
+            Arc::clone(&self.tree_manager),
+        );
+        let mut state = ReconstructionState::new();
+
+        while let Some((sender, unit)) = self.unit_rx.recv().await {
+            // TODO(AndrewL): finalize immediately if first validation fails (DoS attack vector).
+            trace!("[MSG_PROC] Validating unit from sender={:?} index={:?}", sender, unit.index());
+
+            // TODO(AndrewL): process multiple shards simultaneously instead of sequentially.
+            let (result, returned_validator, unit) =
+                Self::validate_blocking(validator, sender, unit).await;
+            validator = returned_validator;
+
+            if let Err(err) = result {
+                // TODO(AndrewL): penalize sender of bad shard.
+                trace!("[MSG_PROC] Validation failed for index={:?}: {:?}", unit.index(), err);
+                continue;
+            }
+
+            self.maybe_broadcast_my_shard(&unit, &state);
+            state.record_shard(unit.index() == self.my_shard_index);
+            state.capture_signature(&unit);
+
+            if self.update_state(unit, &mut state).await.is_break() {
+                return;
+            }
+        }
+
         trace!(
-            "[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
+            "[MSG_PROC] All channels closed for channel={:?} publisher={:?} root={:?}",
             self.channel,
             self.publisher,
             self.message_root
         );
+        self.finalize();
+    }
 
-        self.emit_and_finalize(Event::MessageTimeout {
-            channel: self.channel,
+    /// Offloads CPU-bound validation (signature verification, Merkle proofs) to a blocking thread
+    /// to avoid blocking the tokio runtime.
+    async fn validate_blocking(
+        mut validator: UnitValidator,
+        sender: PeerId,
+        unit: PropellerUnit,
+    ) -> ValidationResult {
+        tokio::task::spawn_blocking(move || {
+            let result = validator.validate_shard(sender, &unit);
+            (result, validator, unit)
+        })
+        .await
+        .expect("Validation task panicked")
+    }
+
+    fn maybe_broadcast_my_shard(&self, unit: &PropellerUnit, state: &ReconstructionState) {
+        if unit.index() == self.my_shard_index && !state.did_broadcast_my_shard() {
+            self.broadcast_unit(unit);
+        }
+    }
+
+    fn broadcast_unit(&self, unit: &PropellerUnit) {
+        let mut peers: Vec<PeerId> = self
+            .tree_manager
+            .get_nodes()
+            .iter()
+            .map(|(p, _)| *p)
+            .filter(|p| *p != self.publisher && *p != self.local_peer_id)
+            .collect();
+        // TODO(AndrewL): get seeded RNG for tests.
+        peers.shuffle(&mut rand::thread_rng());
+        trace!("[MSG_PROC] Broadcasting unit index={:?} to {} peers", unit.index(), peers.len());
+        self.engine_tx
+            .send(EventStateManagerToEngine::SendUnitToPeers { unit: unit.clone(), peers })
+            .expect("Engine task has exited");
+    }
+
+    async fn update_state(
+        &self,
+        unit: PropellerUnit,
+        state: &mut ReconstructionState,
+    ) -> ControlFlow<()> {
+        if state.is_reconstructed() {
+            return self.maybe_emit_message(state);
+        }
+
+        state.push_shard(unit);
+
+        let shard_count = state.received_shard_count();
+        if !self.tree_manager.should_build(shard_count) {
+            return ControlFlow::Continue(());
+        }
+
+        trace!("[MSG_PROC] Starting reconstruction with {} shards", shard_count);
+
+        match self.reconstruct_blocking(state).await {
+            Ok(output) => self.handle_reconstruction_output(output, shard_count, state),
+            Err(e) => {
+                error!("[MSG_PROC] Reconstruction failed: {:?}", e);
+                self.emit_and_finalize(Event::MessageReconstructionFailed {
+                    publisher: self.publisher,
+                    message_root: self.message_root,
+                    error: e,
+                })
+            }
+        }
+    }
+
+    /// Offloads erasure-coding reconstruction to a blocking thread.
+    async fn reconstruct_blocking(&self, state: &mut ReconstructionState) -> ReconstructionResult {
+        let shards = state.take_shards();
+        let message_root = self.message_root;
+        let my_index: usize = self.my_shard_index.0.try_into().unwrap();
+        let data_count = self.tree_manager.num_data_shards();
+        let coding_count = self.tree_manager.num_coding_shards();
+
+        tokio::task::spawn_blocking(move || {
+            reconstruct_message_from_shards(
+                shards,
+                message_root,
+                my_index,
+                data_count,
+                coding_count,
+            )
+            .map(|(message, my_shard, my_shard_proof)| ReconstructionOutput {
+                message,
+                my_shard,
+                my_shard_proof,
+            })
+        })
+        .await
+        .expect("Reconstruction task panicked")
+    }
+
+    fn handle_reconstruction_output(
+        &self,
+        output: ReconstructionOutput,
+        shard_count: usize,
+        state: &mut ReconstructionState,
+    ) -> ControlFlow<()> {
+        let ReconstructionOutput { message, my_shard, my_shard_proof } = output;
+
+        let should_broadcast = !state.did_broadcast_my_shard();
+        if should_broadcast {
+            let signature = match state {
+                ReconstructionState::PreConstruction { signature, .. } => {
+                    signature.clone().expect("Signature must exist")
+                }
+                ReconstructionState::PostConstruction { .. } => {
+                    unreachable!("Cannot be PostConstruction before transition")
+                }
+            };
+            let reconstructed_unit = PropellerUnit::new(
+                self.channel,
+                self.publisher,
+                self.message_root,
+                signature,
+                self.my_shard_index,
+                my_shard,
+                my_shard_proof,
+            );
+            self.broadcast_unit(&reconstructed_unit);
+        }
+
+        let total_shards = shard_count + usize::from(should_broadcast);
+        state.transition_to_post(message, total_shards);
+        self.maybe_emit_message(state)
+    }
+
+    fn maybe_emit_message(&self, state: &mut ReconstructionState) -> ControlFlow<()> {
+        let num = match state {
+            ReconstructionState::PostConstruction { num_received_shards, .. } => {
+                *num_received_shards
+            }
+            _ => return ControlFlow::Continue(()),
+        };
+
+        if !self.tree_manager.should_receive(num) {
+            return ControlFlow::Continue(());
+        }
+
+        trace!("[MSG_PROC] Access threshold reached, emitting message");
+        let message = match state {
+            ReconstructionState::PostConstruction { reconstructed_message, .. } => {
+                reconstructed_message.take().expect("Message already emitted")
+            }
+            _ => unreachable!(),
+        };
+
+        self.emit_and_finalize(Event::MessageReceived {
             publisher: self.publisher,
             message_root: self.message_root,
+            message,
         })
     }
 
+    fn emit_timeout_and_finalize(&self) {
+        trace!(
+            "[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
+            self.channel,
+            self.publisher,
+            self.message_root
+        );
+        let _ = self.emit_and_finalize(Event::MessageTimeout {
+            channel: self.channel,
+            publisher: self.publisher,
+            message_root: self.message_root,
+        });
+    }
+
     fn emit_and_finalize(&self, event: Event) -> ControlFlow<()> {
         self.engine_tx
             .send(EventStateManagerToEngine::BehaviourEvent(event))
             .expect("Engine task has exited");
+        self.finalize();
+        ControlFlow::Break(())
+    }
+
+    fn finalize(&self) {
         self.engine_tx
             .send(EventStateManagerToEngine::Finalized {
                 channel: self.channel,
@@ -93,6 +381,5 @@ impl MessageProcessor {
                 message_root: self.message_root,
             })
             .expect("Engine task has exited");
-        ControlFlow::Break(())
     }
 }