1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions crates/apollo_propeller/Cargo.toml
@@ -16,6 +16,7 @@ asynchronous-codec.workspace = true
futures.workspace = true
libp2p.workspace = true
prost.workspace = true
rand.workspace = true
reed-solomon-simd.workspace = true
sha2.workspace = true
thiserror.workspace = true
329 changes: 308 additions & 21 deletions crates/apollo_propeller/src/message_processor.rs
@@ -3,22 +3,116 @@ use std::sync::Arc;
use std::time::Duration;

use libp2p::identity::{PeerId, PublicKey};
use rand::seq::SliceRandom;
use tokio::sync::mpsc;
use tokio::time::sleep_until;
use tracing::{debug, trace};
use tracing::{debug, error, trace};

use crate::sharding::reconstruct_message_from_shards;
use crate::tree::PropellerScheduleManager;
use crate::types::{Channel, Event, MessageRoot};
use crate::types::{Channel, Event, MessageRoot, ReconstructionError, ShardValidationError};
use crate::unit::PropellerUnit;
use crate::ShardIndex;
use crate::unit_validator::UnitValidator;
use crate::{MerkleProof, ShardIndex};

pub type UnitToValidate = (PeerId, PropellerUnit);
type ValidationResult = (Result<(), ShardValidationError>, UnitValidator, PropellerUnit);
type ReconstructionResult = Result<ReconstructionOutput, ReconstructionError>;

#[derive(Debug)]
pub enum EventStateManagerToEngine {
BehaviourEvent(Event),
Finalized { channel: Channel, publisher: PeerId, message_root: MessageRoot },
BroadcastUnit { unit: PropellerUnit, peers: Vec<PeerId> },
SendUnitToPeers { unit: PropellerUnit, peers: Vec<PeerId> },
}

#[derive(Debug)]
struct ReconstructionOutput {
message: Vec<u8>,
my_shard: Vec<u8>,
my_shard_proof: MerkleProof,
}

/// Tracks reconstruction progress for a single message.
enum ReconstructionState {
PreConstruction {
received_shards: Vec<PropellerUnit>,
did_broadcast_my_shard: bool,
signature: Option<Vec<u8>>,
},
/// Message was reconstructed but not yet delivered to the application. We keep collecting
/// shards until the access threshold is reached, then emit the message.
// No need to track unit indices after reconstruction (duplicate units are already caught during validation)
PostConstruction { reconstructed_message: Option<Vec<u8>>, num_received_shards: usize },
}

impl ReconstructionState {
fn new() -> Self {
Self::PreConstruction {
received_shards: Vec::new(),
did_broadcast_my_shard: false,
signature: None,
}
}

fn is_reconstructed(&self) -> bool {
matches!(self, Self::PostConstruction { .. })
}

fn did_broadcast_my_shard(&self) -> bool {
match self {
Self::PreConstruction { did_broadcast_my_shard, .. } => *did_broadcast_my_shard,
Self::PostConstruction { .. } => true,
}
}

fn record_shard(&mut self, is_my_shard: bool) {
match self {
Self::PreConstruction { did_broadcast_my_shard, .. } => {
if is_my_shard {
*did_broadcast_my_shard = true;
}
}
Self::PostConstruction { num_received_shards, .. } => {
if !is_my_shard {
*num_received_shards += 1;
}
}
}
}

fn capture_signature(&mut self, unit: &PropellerUnit) {
if let Self::PreConstruction { signature, .. } = self {
if signature.is_none() {
*signature = Some(unit.signature().to_vec());
}
// The signature was already validated to be the same for all units.
}
}

fn take_shards(&mut self) -> Vec<PropellerUnit> {
match self {
Self::PreConstruction { received_shards, .. } => std::mem::take(received_shards),
Self::PostConstruction { .. } => Vec::new(),
}
}

fn push_shard(&mut self, unit: PropellerUnit) {
if let Self::PreConstruction { received_shards, .. } = self {
received_shards.push(unit);
}
}

fn received_shard_count(&self) -> usize {
match self {
Self::PreConstruction { received_shards, .. } => received_shards.len(),
Self::PostConstruction { .. } => 0,
}
}

fn transition_to_post(&mut self, message: Vec<u8>, num_received_shards: usize) {
*self =
Self::PostConstruction { reconstructed_message: Some(message), num_received_shards };
}
}
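
// Illustrative sketch only (not part of this change): the intended lifecycle of
// `ReconstructionState`, shown without real shards. Shards are collected while in
// `PreConstruction`; once the message is rebuilt, `transition_to_post` switches to
// counting further shards against the delivery threshold.
#[allow(dead_code)]
fn reconstruction_state_lifecycle_sketch() {
    let mut state = ReconstructionState::new();
    assert!(!state.is_reconstructed());
    assert!(!state.did_broadcast_my_shard());

    // Seeing our own shard index marks it as broadcast (collection itself happens via `push_shard`).
    state.record_shard(true);
    assert!(state.did_broadcast_my_shard());

    // After enough shards arrive, the processor rebuilds the message and transitions.
    state.transition_to_post(b"reconstructed message".to_vec(), 3);
    assert!(state.is_reconstructed());

    // In `PostConstruction`, shards from other peers only bump `num_received_shards`.
    state.record_shard(false);
}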

/// Message processor that handles validation and state management for a single message.
@@ -47,18 +141,10 @@ impl MessageProcessor {
self.channel, self.publisher, self.message_root
);

// Local state variables
let deadline = tokio::time::Instant::now() + self.timeout;
let timed_out = tokio::time::timeout(self.timeout, self.process_units()).await.is_err();

// TODO(AndrewL): remove this
#[allow(clippy::never_loop)]
loop {
tokio::select! {
_ = sleep_until(deadline) => {
let _ = self.emit_timeout_and_finalize().await;
break;
}
}
if timed_out {
self.emit_timeout_and_finalize();
}

debug!(
@@ -67,32 +153,233 @@
);
}
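
/// Illustrative sketch only (not part of this change): the timeout pattern used in `run`
/// above. `tokio::time::timeout` drops the inner future once the deadline passes, so the
/// shard-processing loop needs no `select!` arm of its own.
#[allow(dead_code)]
async fn timeout_pattern_sketch(timeout: Duration) -> bool {
    // Stand-in for `process_units`: a future that never completes on its own.
    let pending_work = std::future::pending::<()>();
    // `is_err()` is true exactly when the deadline fired before the work finished.
    tokio::time::timeout(timeout, pending_work).await.is_err()
}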

async fn emit_timeout_and_finalize(&mut self) -> ControlFlow<()> {
async fn process_units(&mut self) {
let mut validator = UnitValidator::new(
self.channel,
self.publisher,
self.publisher_public_key.clone(),
self.message_root,
Arc::clone(&self.tree_manager),
);
let mut state = ReconstructionState::new();

while let Some((sender, unit)) = self.unit_rx.recv().await {
// TODO(AndrewL): finalize immediately if the first validation fails (DoS attack vector).
trace!("[MSG_PROC] Validating unit from sender={:?} index={:?}", sender, unit.index());

// TODO(AndrewL): process multiple shards simultaneously instead of sequentially.
let (result, returned_validator, unit) =
Self::validate_blocking(validator, sender, unit).await;
validator = returned_validator;

if let Err(err) = result {
// TODO(AndrewL): penalize sender of bad shard.
trace!("[MSG_PROC] Validation failed for index={:?}: {:?}", unit.index(), err);
continue;
}

self.maybe_broadcast_my_shard(&unit, &state);
state.record_shard(unit.index() == self.my_shard_index);
state.capture_signature(&unit);

if self.update_state(unit, &mut state).await.is_break() {
return;
}
}

trace!(
"[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
"[MSG_PROC] All channels closed for channel={:?} publisher={:?} root={:?}",
self.channel,
self.publisher,
self.message_root
);
self.finalize();
}

self.emit_and_finalize(Event::MessageTimeout {
channel: self.channel,
/// Offloads CPU-bound validation (signature verification, Merkle proofs) to a blocking thread
/// to avoid blocking the tokio runtime.
async fn validate_blocking(
mut validator: UnitValidator,
sender: PeerId,
unit: PropellerUnit,
) -> ValidationResult {
tokio::task::spawn_blocking(move || {
let result = validator.validate_shard(sender, &unit);
(result, validator, unit)
})
.await
.expect("Validation task panicked")
}
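
/// Illustrative sketch only (not part of this change): the ownership round-trip used by
/// `validate_blocking` above. `spawn_blocking` needs `'static + Send` data, so the value
/// is moved into the closure and handed back in the return value instead of being borrowed.
#[allow(dead_code)]
async fn spawn_blocking_round_trip_sketch<T, R>(
    value: T,
    work: impl FnOnce(&mut T) -> R + Send + 'static,
) -> (R, T)
where
    T: Send + 'static,
    R: Send + 'static,
{
    tokio::task::spawn_blocking(move || {
        let mut value = value;
        let result = work(&mut value);
        (result, value)
    })
    .await
    .expect("Blocking task panicked")
}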

fn maybe_broadcast_my_shard(&self, unit: &PropellerUnit, state: &ReconstructionState) {
if unit.index() == self.my_shard_index && !state.did_broadcast_my_shard() {
self.broadcast_unit(unit);
}
}

fn broadcast_unit(&self, unit: &PropellerUnit) {
let mut peers: Vec<PeerId> = self
.tree_manager
.get_nodes()
.iter()
.map(|(p, _)| *p)
.filter(|p| *p != self.publisher && *p != self.local_peer_id)
.collect();
// TODO(AndrewL): get seeded RNG for tests.
peers.shuffle(&mut rand::thread_rng());
trace!("[MSG_PROC] Broadcasting unit index={:?} to {} peers", unit.index(), peers.len());
self.engine_tx
.send(EventStateManagerToEngine::SendUnitToPeers { unit: unit.clone(), peers })
.expect("Engine task has exited");
}
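
/// Illustrative sketch only, addressing the seeded-RNG TODO above: shuffling with a
/// `StdRng` seeded from a caller-supplied value makes the broadcast order reproducible
/// in tests. The `seed` parameter is hypothetical; nothing threads it through today.
#[allow(dead_code)]
fn shuffle_peers_with_seed(peers: &mut [PeerId], seed: u64) {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Same shuffle as `broadcast_unit`, but deterministic for a given seed.
    peers.shuffle(&mut StdRng::seed_from_u64(seed));
}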

async fn update_state(
&self,
unit: PropellerUnit,
state: &mut ReconstructionState,
) -> ControlFlow<()> {
if state.is_reconstructed() {
return self.maybe_emit_message(state);
}

state.push_shard(unit);

let shard_count = state.received_shard_count();
if !self.tree_manager.should_build(shard_count) {
return ControlFlow::Continue(());
}

trace!("[MSG_PROC] Starting reconstruction with {} shards", shard_count);

match self.reconstruct_blocking(state).await {
Ok(output) => self.handle_reconstruction_output(output, shard_count, state),
Err(e) => {
error!("[MSG_PROC] Reconstruction failed: {:?}", e);
self.emit_and_finalize(Event::MessageReconstructionFailed {
publisher: self.publisher,
message_root: self.message_root,
error: e,
})
}
}
}
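
/// Illustrative sketch only (not part of this change): why the state-update helpers
/// return `ControlFlow<()>`. The caller in `process_units` stops its `recv` loop as soon
/// as a helper reports `Break`, without threading a separate boolean flag around.
#[allow(dead_code)]
fn control_flow_sketch(done: bool) -> ControlFlow<()> {
    if done {
        // Mirrors `emit_and_finalize`: tell the caller to stop processing units.
        ControlFlow::Break(())
    } else {
        ControlFlow::Continue(())
    }
}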

/// Offloads erasure-coding reconstruction to a blocking thread.
async fn reconstruct_blocking(&self, state: &mut ReconstructionState) -> ReconstructionResult {
let shards = state.take_shards();
let message_root = self.message_root;
let my_index: usize = self.my_shard_index.0.try_into().unwrap();
let data_count = self.tree_manager.num_data_shards();
let coding_count = self.tree_manager.num_coding_shards();

tokio::task::spawn_blocking(move || {
reconstruct_message_from_shards(
shards,
message_root,
my_index,
data_count,
coding_count,
)
.map(|(message, my_shard, my_shard_proof)| ReconstructionOutput {
message,
my_shard,
my_shard_proof,
})
})
.await
.expect("Reconstruction task panicked")
}
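
/// Illustrative sketch only (not part of this change): the threshold arithmetic assumed
/// behind `should_build` and `should_receive`. With Reed-Solomon coding, any
/// `num_data_shards` of the `num_data_shards + num_coding_shards` total shards suffice to
/// rebuild the message; the delivery ("access") threshold applies the same kind of count
/// check before emitting. The real policy lives in `PropellerScheduleManager`, so this is
/// an assumption about its behaviour, not a copy of it.
#[allow(dead_code)]
fn reconstruction_threshold_sketch(
    num_data_shards: usize,
    num_coding_shards: usize,
    received: usize,
) -> bool {
    debug_assert!(received <= num_data_shards + num_coding_shards);
    received >= num_data_shards
}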

fn handle_reconstruction_output(
&self,
output: ReconstructionOutput,
shard_count: usize,
state: &mut ReconstructionState,
) -> ControlFlow<()> {
let ReconstructionOutput { message, my_shard, my_shard_proof } = output;

let should_broadcast = !state.did_broadcast_my_shard();
if should_broadcast {
let signature = match state {
ReconstructionState::PreConstruction { signature, .. } => {
signature.clone().expect("Signature must exist")
}
ReconstructionState::PostConstruction { .. } => {
unreachable!("Cannot be PostConstruction before transition")
}
};
let reconstructed_unit = PropellerUnit::new(
self.channel,
self.publisher,
self.message_root,
signature,
self.my_shard_index,
my_shard,
my_shard_proof,
);
self.broadcast_unit(&reconstructed_unit);
}

let total_shards = shard_count + usize::from(should_broadcast);
state.transition_to_post(message, total_shards);
self.maybe_emit_message(state)
}

fn maybe_emit_message(&self, state: &mut ReconstructionState) -> ControlFlow<()> {
let num = match state {
ReconstructionState::PostConstruction { num_received_shards, .. } => {
*num_received_shards
}
_ => return ControlFlow::Continue(()),
};

if !self.tree_manager.should_receive(num) {
return ControlFlow::Continue(());
}

trace!("[MSG_PROC] Access threshold reached, emitting message");
let message = match state {
ReconstructionState::PostConstruction { reconstructed_message, .. } => {
reconstructed_message.take().expect("Message already emitted")
}
_ => unreachable!(),
};
self.emit_and_finalize(Event::MessageReceived {
publisher: self.publisher,
message_root: self.message_root,
message,
})
}

fn emit_timeout_and_finalize(&self) {
trace!(
"[MSG_PROC] Timeout reached for channel={:?} publisher={:?} root={:?}",
self.channel,
self.publisher,
self.message_root
);
let _ = self.emit_and_finalize(Event::MessageTimeout {
channel: self.channel,
publisher: self.publisher,
message_root: self.message_root,
});
}

fn emit_and_finalize(&self, event: Event) -> ControlFlow<()> {
self.engine_tx
.send(EventStateManagerToEngine::BehaviourEvent(event))
.expect("Engine task has exited");
self.finalize();
ControlFlow::Break(())
}

fn finalize(&self) {
self.engine_tx
.send(EventStateManagerToEngine::Finalized {
channel: self.channel,
publisher: self.publisher,
message_root: self.message_root,
})
.expect("Engine task has exited");
ControlFlow::Break(())
}
}
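
// Illustrative sketch only (not part of this change): a minimal consumer for
// `EventStateManagerToEngine`, assuming the engine side holds an
// `mpsc::UnboundedReceiver<EventStateManagerToEngine>` (only the sender is visible in
// this file, so the channel type is an assumption).
#[allow(dead_code)]
async fn engine_loop_sketch(
    mut from_processor: mpsc::UnboundedReceiver<EventStateManagerToEngine>,
) {
    while let Some(message) = from_processor.recv().await {
        match message {
            EventStateManagerToEngine::BehaviourEvent(event) => {
                // Surface the event (e.g. MessageReceived / MessageTimeout) to the behaviour.
                debug!("Behaviour event from processor: {:?}", event);
            }
            EventStateManagerToEngine::SendUnitToPeers { unit, peers } => {
                // Hand the shard to the transport layer for delivery.
                trace!("Sending unit index={:?} to {} peers", unit.index(), peers.len());
            }
            other => {
                // `Finalized` (and any remaining variants) would tear down per-message state.
                debug!("Processor signalled: {:?}", other);
            }
        }
    }
}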