Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions crates/net/network/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ pub struct NetworkConfig<C, N: NetworkPrimitives = EthNetworkPrimitives> {
/// If non-empty, peers that don't have these blocks will be filtered out.
pub required_block_hashes: Vec<B256>,
/// A transformation hook applied to the downloaded headers.
pub header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
pub header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
}

// === impl NetworkConfig ===
Expand Down Expand Up @@ -232,7 +232,7 @@ pub struct NetworkConfigBuilder<N: NetworkPrimitives = EthNetworkPrimitives> {
/// Optional network id
network_id: Option<u64>,
/// The header transform type.
header_transform: Option<Box<dyn HeaderTransform<N::BlockHeader>>>,
header_transform: Option<Arc<dyn HeaderTransform<N::BlockHeader>>>,
}

impl NetworkConfigBuilder<EthNetworkPrimitives> {
Expand Down Expand Up @@ -605,7 +605,7 @@ impl<N: NetworkPrimitives> NetworkConfigBuilder<N> {
/// Sets the header transform type.
pub fn header_transform(
mut self,
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
) -> Self {
self.header_transform = Some(header_transform);
self
Expand Down Expand Up @@ -717,7 +717,7 @@ impl<N: NetworkPrimitives> NetworkConfigBuilder<N> {
nat,
handshake,
required_block_hashes,
header_transform: header_transform.unwrap_or_else(|| Box::new(())),
header_transform: header_transform.unwrap_or_else(|| Arc::new(())),
}
}
}
Expand Down
31 changes: 19 additions & 12 deletions crates/net/network/src/fetch/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub use client::FetchClient;

use crate::{message::BlockRequest, session::BlockRangeInfo, transform::header::HeaderTransform};
use alloy_primitives::B256;
use futures::StreamExt;
use futures::{future::join_all, StreamExt};
use reth_eth_wire::{EthNetworkPrimitives, GetBlockBodies, GetBlockHeaders, NetworkPrimitives};
use reth_network_api::test_utils::PeersHandle;
use reth_network_p2p::{
Expand Down Expand Up @@ -56,7 +56,7 @@ pub struct StateFetcher<N: NetworkPrimitives = EthNetworkPrimitives> {
/// Sender for download requests, used to detach a [`FetchClient`]
download_requests_tx: UnboundedSender<DownloadRequest<N>>,
/// A transformation hook applied to the downloaded headers.
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
}

// === impl StateSyncer ===
Expand All @@ -65,7 +65,7 @@ impl<N: NetworkPrimitives> StateFetcher<N> {
pub(crate) fn new(
peers_handle: PeersHandle,
num_active_peers: Arc<AtomicUsize>,
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
) -> Self {
let (download_requests_tx, download_requests_rx) = mpsc::unbounded_channel();
Self {
Expand Down Expand Up @@ -279,10 +279,17 @@ impl<N: NetworkPrimitives> StateFetcher<N> {
resp.as_ref().is_some_and(|r| res.is_likely_bad_headers_response(&r.request));

if let Some(resp) = resp {
// apply the header transform and delegate the response
let _ = resp.response.send(res.map(|h| {
(peer_id, h.into_iter().map(|h| self.header_transform.map(h)).collect()).into()
}));
let header_transform = self.header_transform.clone();
tokio::spawn(async move {
let res = match res {
Ok(headers) => {
Ok(join_all(headers.into_iter().map(|h| header_transform.map(h))).await)
}
Err(e) => Err(e),
};

let _ = resp.response.send(res.map(|h| (peer_id, h).into()));
});
}

if let Some(peer) = self.peers.get_mut(&peer_id) {
Expand Down Expand Up @@ -496,7 +503,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);

poll_fn(move |cx| {
Expand All @@ -521,7 +528,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
// Add a few random peers
let peer1 = B512::random();
Expand All @@ -548,7 +555,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
// Add a few random peers
let peer1 = B512::random();
Expand Down Expand Up @@ -577,7 +584,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
let peer_id = B512::random();

Expand Down Expand Up @@ -611,7 +618,7 @@ mod tests {
let mut fetcher = StateFetcher::<EthNetworkPrimitives>::new(
manager.handle(),
Default::default(),
Box::new(()),
Arc::new(()),
);
let peer_id = B512::random();

Expand Down
4 changes: 2 additions & 2 deletions crates/net/network/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ impl<N: NetworkPrimitives> NetworkState<N> {
discovery: Discovery,
peers_manager: PeersManager,
num_active_peers: Arc<AtomicUsize>,
header_transform: Box<dyn HeaderTransform<N::BlockHeader>>,
header_transform: Arc<dyn HeaderTransform<N::BlockHeader>>,
) -> Self {
let state_fetcher =
StateFetcher::new(peers_manager.handle(), num_active_peers, header_transform);
Expand Down Expand Up @@ -582,7 +582,7 @@ mod tests {
queued_messages: Default::default(),
client: BlockNumReader(Box::new(NoopProvider::default())),
discovery: Discovery::noop(),
state_fetcher: StateFetcher::new(handle, Default::default(), Box::new(())),
state_fetcher: StateFetcher::new(handle, Default::default(), Arc::new(())),
}
}

Expand Down
6 changes: 4 additions & 2 deletions crates/net/network/src/transform/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
use reth_primitives_traits::BlockHeader;

/// An instance of the trait applies a mapping to the input header.
#[async_trait::async_trait]
pub trait HeaderTransform<H: BlockHeader>: std::fmt::Debug + Send + Sync {
/// Applies a mapping to the input header.
fn map(&self, header: H) -> H;
async fn map(&self, header: H) -> H;
}

#[async_trait::async_trait]
impl<H: BlockHeader> HeaderTransform<H> for () {
fn map(&self, header: H) -> H {
async fn map(&self, header: H) -> H {
header
}
}
Expand Down
Loading