Skip to content

Commit b3e5579

Browse files
feat(gossipsub): allowed spawning tasks for message verification
1 parent 78bad44 commit b3e5579

File tree

3 files changed

+240
-15
lines changed

3 files changed

+240
-15
lines changed

protocols/gossipsub/src/behaviour/tests.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6654,6 +6654,7 @@ fn test_validation_error_message_size_too_large_topic_specific() {
66546654
Config::default_max_transmit_size() * 2,
66556655
ValidationMode::None,
66566656
max_transmit_size_map,
6657+
None,
66576658
);
66586659
let mut buf = BytesMut::new();
66596660
let rpc = proto::RPC {
@@ -6758,6 +6759,7 @@ fn test_validation_message_size_within_topic_specific() {
67586759
Config::default_max_transmit_size() * 2,
67596760
ValidationMode::None,
67606761
max_transmit_size_map,
6762+
None,
67616763
);
67626764
let mut buf = BytesMut::new();
67636765
let rpc = proto::RPC {

protocols/gossipsub/src/config.rs

Lines changed: 106 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1919
// DEALINGS IN THE SOFTWARE.
2020

21-
use std::{borrow::Cow, collections::HashMap, sync::Arc, time::Duration};
21+
use std::{borrow::Cow, collections::HashMap, future::Future, pin::Pin, sync::Arc, time::Duration};
2222

2323
use libp2p_identity::PeerId;
2424
use libp2p_swarm::StreamProtocol;
@@ -133,6 +133,17 @@ pub struct Config {
133133
idontwant_message_size_threshold: usize,
134134
idontwant_on_publish: bool,
135135
topic_configuration: TopicConfigs,
136+
message_verification_spawner: Option<
137+
Arc<
138+
dyn Fn(
139+
Box<dyn FnOnce() -> crate::protocol::ValidationResult + Send>,
140+
)
141+
-> Pin<Box<dyn Future<Output = crate::protocol::ValidationResult> + Send>>
142+
+ Send
143+
+ Sync
144+
+ 'static,
145+
>,
146+
>,
136147
}
137148

138149
impl Config {
@@ -476,6 +487,27 @@ impl Config {
476487
pub fn idontwant_on_publish(&self) -> bool {
477488
self.idontwant_on_publish
478489
}
490+
491+
/// Optional spawner for message verification.
492+
///
493+
/// This allows users to provide a custom spawner that takes a closure and runs it,
494+
/// returning a future that resolves to a ValidationResult. This can be used to run
495+
/// message verification on a different thread or async runtime.
496+
pub fn message_verification_spawner(
497+
&self,
498+
) -> Option<
499+
&Arc<
500+
dyn Fn(
501+
Box<dyn FnOnce() -> crate::protocol::ValidationResult + Send>,
502+
)
503+
-> Pin<Box<dyn Future<Output = crate::protocol::ValidationResult> + Send>>
504+
+ Send
505+
+ Sync
506+
+ 'static,
507+
>,
508+
> {
509+
self.message_verification_spawner.as_ref()
510+
}
479511
}
480512

481513
impl Default for Config {
@@ -546,6 +578,7 @@ impl Default for ConfigBuilder {
546578
idontwant_message_size_threshold: 1000,
547579
idontwant_on_publish: false,
548580
topic_configuration: TopicConfigs::default(),
581+
message_verification_spawner: None,
549582
},
550583
invalid_protocol: false,
551584
}
@@ -1084,6 +1117,36 @@ impl ConfigBuilder {
10841117
self
10851118
}
10861119

1120+
/// Sets a custom spawner for message verification.
1121+
///
1122+
/// The spawner takes a closure that returns a ValidationResult and should execute it,
1123+
/// returning a future that resolves to the result. This allows running message
1124+
/// verification on a different thread or async runtime. If not set, message
1125+
/// verification will run synchronously.
1126+
///
1127+
/// # Example
1128+
/// ```rust
1129+
/// libp2p_gossipsub::ConfigBuilder::default()
1130+
/// .message_verification_spawner(|closure| {
1131+
/// Box::pin(async move { tokio::task::spawn_blocking(closure).await.unwrap() })
1132+
/// })
1133+
/// .build()
1134+
/// .unwrap();
1135+
/// ```
1136+
pub fn message_verification_spawner<F>(&mut self, spawner: F) -> &mut Self
1137+
where
1138+
F: Fn(
1139+
Box<dyn FnOnce() -> crate::protocol::ValidationResult + Send>,
1140+
)
1141+
-> Pin<Box<dyn Future<Output = crate::protocol::ValidationResult> + Send>>
1142+
+ Send
1143+
+ Sync
1144+
+ 'static,
1145+
{
1146+
self.config.message_verification_spawner = Some(Arc::new(spawner));
1147+
self
1148+
}
1149+
10871150
/// Constructs a [`Config`] from the given configuration and validates the settings.
10881151
pub fn build(&self) -> Result<Config, ConfigBuilderError> {
10891152
// check all constraints on config
@@ -1304,4 +1367,46 @@ mod test {
13041367
v.push('e');
13051368
MessageId::from(v)
13061369
}
1370+
1371+
#[tokio::test]
1372+
async fn test_message_verification_spawner_with_validation_result() {
1373+
use std::sync::{
1374+
atomic::{AtomicBool, Ordering},
1375+
Arc,
1376+
};
1377+
1378+
let spawner_called = Arc::new(AtomicBool::new(false));
1379+
let spawner_called_clone = spawner_called.clone();
1380+
1381+
let config = ConfigBuilder::default()
1382+
.message_verification_spawner(move |validation_fn| {
1383+
spawner_called_clone.store(true, Ordering::Relaxed);
1384+
Box::pin(async move { validation_fn() })
1385+
})
1386+
.build()
1387+
.unwrap();
1388+
1389+
assert!(config.message_verification_spawner().is_some());
1390+
1391+
// Test that the spawner can be used
1392+
if let Some(spawner) = config.message_verification_spawner() {
1393+
let result = spawner(Box::new(|| crate::protocol::ValidationResult::Valid {
1394+
source: None,
1395+
sequence_number: None,
1396+
}));
1397+
1398+
match result.await {
1399+
crate::protocol::ValidationResult::Valid { .. } => {
1400+
// Test passed
1401+
}
1402+
_ => panic!("Expected Valid result"),
1403+
}
1404+
}
1405+
1406+
// Verify the spawner was called
1407+
assert!(
1408+
spawner_called.load(Ordering::Relaxed),
1409+
"Spawner should have been called"
1410+
);
1411+
}
13071412
}

protocols/gossipsub/src/protocol.rs

Lines changed: 132 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,14 @@
1818
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
1919
// DEALINGS IN THE SOFTWARE.
2020

21-
use std::{collections::HashMap, convert::Infallible, pin::Pin};
21+
use std::{
22+
collections::HashMap,
23+
convert::Infallible,
24+
future::Future,
25+
pin::Pin,
26+
sync::Arc,
27+
task::{Context, Poll},
28+
};
2229

2330
use asynchronous_codec::{Decoder, Encoder, Framed};
2431
use byteorder::{BigEndian, ByteOrder};
@@ -62,7 +69,7 @@ pub(crate) const FLOODSUB_PROTOCOL: ProtocolId = ProtocolId {
6269
};
6370

6471
/// Implementation of [`InboundUpgrade`] and [`OutboundUpgrade`] for the Gossipsub protocol.
65-
#[derive(Debug, Clone)]
72+
#[derive(Clone)]
6673
pub struct ProtocolConfig {
6774
/// The Gossipsub protocol id to listen on.
6875
pub(crate) protocol_ids: Vec<ProtocolId>,
@@ -72,6 +79,17 @@ pub struct ProtocolConfig {
7279
pub(crate) default_max_transmit_size: usize,
7380
/// The max transmit sizes for a topic.
7481
pub(crate) max_transmit_sizes: HashMap<TopicHash, usize>,
82+
/// Optional spawner for message verification.
83+
pub(crate) message_verification_spawner: Option<
84+
Arc<
85+
dyn Fn(
86+
Box<dyn FnOnce() -> ValidationResult + Send>,
87+
) -> Pin<Box<dyn Future<Output = ValidationResult> + Send>>
88+
+ Send
89+
+ Sync
90+
+ 'static,
91+
>,
92+
>,
7593
}
7694

7795
impl Default for ProtocolConfig {
@@ -85,10 +103,26 @@ impl Default for ProtocolConfig {
85103
],
86104
default_max_transmit_size: 65536,
87105
max_transmit_sizes: HashMap::new(),
106+
message_verification_spawner: None,
88107
}
89108
}
90109
}
91110

111+
impl std::fmt::Debug for ProtocolConfig {
112+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113+
f.debug_struct("ProtocolConfig")
114+
.field("protocol_ids", &self.protocol_ids)
115+
.field("validation_mode", &self.validation_mode)
116+
.field("default_max_transmit_size", &self.default_max_transmit_size)
117+
.field("max_transmit_sizes", &self.max_transmit_sizes)
118+
.field(
119+
"message_verification_spawner",
120+
&self.message_verification_spawner.is_some(),
121+
)
122+
.finish()
123+
}
124+
}
125+
92126
impl ProtocolConfig {
93127
/// Get the max transmit size for a given topic, falling back to the default.
94128
pub fn max_transmit_size_for_topic(&self, topic: &TopicHash) -> usize {
@@ -139,6 +173,7 @@ where
139173
self.default_max_transmit_size,
140174
self.validation_mode,
141175
self.max_transmit_sizes,
176+
self.message_verification_spawner,
142177
),
143178
),
144179
protocol_id.kind,
@@ -162,6 +197,7 @@ where
162197
self.default_max_transmit_size,
163198
self.validation_mode,
164199
self.max_transmit_sizes,
200+
self.message_verification_spawner,
165201
),
166202
),
167203
protocol_id.kind,
@@ -178,11 +214,22 @@ pub struct GossipsubCodec {
178214
codec: quick_protobuf_codec::Codec<proto::RPC>,
179215
/// Maximum transmit sizes per topic, with a default if not specified.
180216
max_transmit_sizes: HashMap<TopicHash, usize>,
217+
/// Optional spawner for message verification.
218+
message_verification_spawner: Option<
219+
Arc<
220+
dyn Fn(
221+
Box<dyn FnOnce() -> ValidationResult + Send>,
222+
) -> Pin<Box<dyn Future<Output = ValidationResult> + Send>>
223+
+ Send
224+
+ Sync
225+
+ 'static,
226+
>,
227+
>,
181228
}
182229

183230
/// Result of message validation
184231
#[derive(Debug)]
185-
enum ValidationResult {
232+
pub enum ValidationResult {
186233
Valid {
187234
source: Option<PeerId>,
188235
sequence_number: Option<u64>,
@@ -195,12 +242,24 @@ impl GossipsubCodec {
195242
max_length: usize,
196243
validation_mode: ValidationMode,
197244
max_transmit_sizes: HashMap<TopicHash, usize>,
245+
message_verification_spawner: Option<
246+
Arc<
247+
dyn Fn(
248+
Box<dyn FnOnce() -> ValidationResult + Send>,
249+
)
250+
-> Pin<Box<dyn Future<Output = ValidationResult> + Send>>
251+
+ Send
252+
+ Sync
253+
+ 'static,
254+
>,
255+
>,
198256
) -> GossipsubCodec {
199257
let codec = quick_protobuf_codec::Codec::new(max_length);
200258
GossipsubCodec {
201259
validation_mode,
202260
codec,
203261
max_transmit_sizes,
262+
message_verification_spawner,
204263
}
205264
}
206265

@@ -408,13 +467,70 @@ impl Decoder for GossipsubCodec {
408467
// Store any invalid messages.
409468
let mut invalid_messages = Vec::new();
410469

411-
for message in rpc.publish.into_iter() {
412-
// Validate the message using the extracted validation function
413-
match Self::validate_message(
414-
self.validation_mode.clone(),
415-
&self.max_transmit_sizes,
416-
&message,
417-
) {
470+
// Collect validation results and corresponding messages
471+
let mut validation_results = Vec::new();
472+
let mut validation_futures = Vec::new();
473+
let messages_to_process: Vec<_> = rpc.publish.into_iter().collect();
474+
475+
for message in &messages_to_process {
476+
if let Some(spawner) = &self.message_verification_spawner {
477+
// Use spawner - defer validation
478+
let validation_mode = self.validation_mode.clone();
479+
let max_transmit_sizes = self.max_transmit_sizes.clone();
480+
let message_clone = message.clone();
481+
let future = spawner(Box::new(move || {
482+
Self::validate_message(validation_mode, &max_transmit_sizes, &message_clone)
483+
}));
484+
validation_futures.push(future);
485+
} else {
486+
// No spawner - validate immediately
487+
let result = Self::validate_message(
488+
self.validation_mode.clone(),
489+
&self.max_transmit_sizes,
490+
message,
491+
);
492+
validation_results.push(result);
493+
}
494+
}
495+
496+
// Poll all validation futures until completion
497+
if !validation_futures.is_empty() {
498+
let waker = futures::task::noop_waker();
499+
let mut context = Context::from_waker(&waker);
500+
let mut future_results: Vec<Option<ValidationResult>> =
501+
(0..validation_futures.len()).map(|_| None).collect();
502+
503+
// Poll until all futures are ready
504+
loop {
505+
let mut all_ready = true;
506+
for (i, future) in validation_futures.iter_mut().enumerate() {
507+
if future_results[i].is_none() {
508+
match future.as_mut().poll(&mut context) {
509+
Poll::Ready(result) => {
510+
future_results[i] = Some(result);
511+
}
512+
Poll::Pending => {
513+
all_ready = false;
514+
std::thread::yield_now();
515+
}
516+
}
517+
}
518+
}
519+
if all_ready {
520+
break;
521+
}
522+
}
523+
524+
// Move future results into the main validation_results vector
525+
validation_results.extend(future_results.into_iter().map(|r| r.unwrap()));
526+
}
527+
528+
// Process all validation results uniformly
529+
for (message, validation_result) in messages_to_process
530+
.into_iter()
531+
.zip(validation_results.into_iter())
532+
{
533+
match validation_result {
418534
ValidationResult::Valid {
419535
source,
420536
sequence_number,
@@ -455,8 +571,6 @@ impl Decoder for GossipsubCodec {
455571
validated: false,
456572
};
457573
invalid_messages.push((raw_message, validation_error));
458-
// proceed to the next message
459-
continue;
460574
}
461575
}
462576
}
@@ -650,8 +764,12 @@ mod tests {
650764
timeout: Delay::new(Duration::from_secs(1)),
651765
};
652766

653-
let mut codec =
654-
GossipsubCodec::new(u32::MAX as usize, ValidationMode::Strict, HashMap::new());
767+
let mut codec = GossipsubCodec::new(
768+
u32::MAX as usize,
769+
ValidationMode::Strict,
770+
HashMap::new(),
771+
None,
772+
);
655773
let mut buf = BytesMut::new();
656774
codec.encode(rpc.into_protobuf(), &mut buf).unwrap();
657775
let decoded_rpc = codec.decode(&mut buf).unwrap().unwrap();

0 commit comments

Comments
 (0)