Skip to content

Commit 8f2716e

Browse files
Custom protocol (#158)
* add OnceCell protocol bytes * store explicit arrays * include pr changes * adjust config to use bytes * initialize oncecells * adjust for tests * comment for clarity * fmt * remove everything goes allow in kbucket key * clippy
1 parent 1bc428f commit 8f2716e

File tree

5 files changed

+100
-18
lines changed

5 files changed

+100
-18
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ tokio-util = { version = "0.6.9", features = ["time"] }
1919
libp2p-core = { version = "0.36.0", optional = true }
2020
zeroize = { version = "1.4.3", features = ["zeroize_derive"] }
2121
futures = "0.3.19"
22-
uint = { version = "0.9.1", default-features = false }
22+
uint = { version = "0.9.5", default-features = false }
2323
rlp = "0.5.1"
2424
# This version must be kept up to date do it uses the same dependencies as ENR
2525
hkdf = "0.12.3"
@@ -39,6 +39,7 @@ lru = "0.7.1"
3939
hashlink = "0.7.0"
4040
delay_map = "0.1.1"
4141
more-asserts = "0.2.2"
42+
once_cell = "1.17.0"
4243

4344
[dev-dependencies]
4445
rand_07 = { package = "rand", version = "0.7" }

src/config.rs

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
use crate::{
2-
ipmode::IpMode, kbucket::MAX_NODES_PER_BUCKET, Enr, Executor, PermitBanList, RateLimiter,
3-
RateLimiterBuilder,
2+
ipmode::IpMode,
3+
kbucket::MAX_NODES_PER_BUCKET,
4+
packet::{PROTOCOL_ID_LENGTH, PROTOCOL_VERSION_LENGTH},
5+
Enr, Executor, PermitBanList, RateLimiter, RateLimiterBuilder,
46
};
57
///! A set of configuration parameters to tune the discovery protocol.
68
use std::time::Duration;
9+
use std::{convert::TryInto, num::NonZeroU16};
10+
11+
/// Protocol ID sent with each message.
12+
pub(crate) const DEFAULT_PROTOCOL_ID: [u8; PROTOCOL_ID_LENGTH] = *b"discv5";
13+
/// The version sent with each handshake.
14+
pub(crate) const DEFAULT_PROTOCOL_VERSION: [u8; PROTOCOL_VERSION_LENGTH] = 0x0001_u16.to_be_bytes();
715

816
/// Configuration parameters that define the performance of the discovery network.
917
#[derive(Clone)]
@@ -99,6 +107,9 @@ pub struct Discv5Config {
99107
/// A custom executor which can spawn the discv5 tasks. This must be a tokio runtime, with
100108
/// timing support. By default, the executor that created the discv5 struct will be used.
101109
pub executor: Option<Box<dyn Executor + Send + Sync>>,
110+
111+
/// The Discv5 protocol id and version, in bytes.
112+
pub protocol: ([u8; PROTOCOL_ID_LENGTH], [u8; PROTOCOL_VERSION_LENGTH]),
102113
}
103114

104115
impl Default for Discv5Config {
@@ -138,6 +149,7 @@ impl Default for Discv5Config {
138149
ban_duration: Some(Duration::from_secs(3600)), // 1 hour
139150
ip_mode: IpMode::default(),
140151
executor: None,
152+
protocol: (DEFAULT_PROTOCOL_ID, DEFAULT_PROTOCOL_VERSION),
141153
}
142154
}
143155
}
@@ -314,6 +326,30 @@ impl Discv5ConfigBuilder {
314326
self
315327
}
316328

329+
/// Set the discv5 wire protocol id.
330+
pub fn protocol_id(&mut self, protocol_id: &'static str) -> &mut Self {
331+
let protocol_id: [u8; PROTOCOL_ID_LENGTH] =
332+
protocol_id.as_bytes().try_into().unwrap_or_else(|_| {
333+
panic!("The protocol id must be {} bytes long", PROTOCOL_ID_LENGTH)
334+
});
335+
336+
self.config.protocol = (protocol_id, DEFAULT_PROTOCOL_VERSION);
337+
self
338+
}
339+
340+
/// Set the discv5 wire protocol id and the version.
341+
pub fn protocol(
342+
&mut self,
343+
protocol_id: &'static str,
344+
protocol_version: NonZeroU16,
345+
) -> &mut Self {
346+
self.protocol_id(protocol_id);
347+
let protocol_version = protocol_version.get().to_be_bytes();
348+
349+
self.config.protocol.1 = protocol_version;
350+
self
351+
}
352+
317353
pub fn build(&mut self) -> Discv5Config {
318354
// If an executor is not provided, assume a current tokio runtime is running.
319355
if self.config.executor.is_none() {
@@ -347,6 +383,7 @@ impl std::fmt::Debug for Discv5Config {
347383
.field("incoming_bucket_limit", &self.incoming_bucket_limit)
348384
.field("ping_interval", &self.ping_interval)
349385
.field("ban_duration", &self.ban_duration)
386+
.field("protocol", &self.protocol)
350387
.finish()
351388
}
352389
}

src/discv5.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,39 @@ impl Discv5 {
9191
enr_key: CombinedKey,
9292
mut config: Discv5Config,
9393
) -> Result<Self, &'static str> {
94+
// tests use the default value, so we ignore initializing the protocol.
95+
#[cfg(not(test))]
96+
{
97+
use crate::{
98+
config::{DEFAULT_PROTOCOL_ID, DEFAULT_PROTOCOL_VERSION},
99+
packet::{PROTOCOL_ID, VERSION},
100+
};
101+
// initialize the protocol id and version
102+
let (protocol_id_bytes, protocol_version_bytes) = config.protocol;
103+
PROTOCOL_ID
104+
.set(protocol_id_bytes)
105+
.map_err(|_old_val| "PROTOCOL_ID has already been initialized")?;
106+
VERSION
107+
.set(protocol_version_bytes)
108+
.map_err(|_old_val| "protocol's VERSION has already been initialized")?;
109+
110+
if protocol_id_bytes != DEFAULT_PROTOCOL_ID
111+
|| protocol_version_bytes != DEFAULT_PROTOCOL_VERSION
112+
{
113+
let protocol_version = u16::from_be_bytes(protocol_version_bytes);
114+
match std::str::from_utf8(&protocol_id_bytes) {
115+
Ok(pretty_protocol_id) => tracing::info!(
116+
"Discv5 using custom protocol id and version. Id: {} Version: {}",
117+
pretty_protocol_id, protocol_version
118+
),
119+
Err(_) => tracing::info!(
120+
"Discv5 using custom protocol id and version, with non utf8 protocol id. Id: {:?} Version: {}",
121+
protocol_id_bytes, protocol_version
122+
),
123+
}
124+
}
125+
}
126+
94127
// ensure the keypair matches the one that signed the enr.
95128
if local_enr.public_key() != enr_key.public() {
96129
return Err("Provided keypair does not match the provided ENR");

src/kbucket/key.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
// This basis of this file has been taken from the rust-libp2p codebase:
2222
// https://github.com/libp2p/rust-libp2p
2323

24-
#![allow(clippy::all)]
25-
24+
#![allow(clippy::assign_op_pattern)]
2625
use enr::{
2726
k256::sha2::digest::generic_array::{typenum::U32, GenericArray},
2827
NodeId,

src/packet/mod.rs

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@ use aes::{
1515
Aes128Ctr,
1616
};
1717
use enr::NodeId;
18+
use once_cell::sync::OnceCell;
1819
use rand::Rng;
1920
use std::convert::TryInto;
2021
use zeroize::Zeroize;
2122

2223
/// The packet IV length (u128).
2324
pub const IV_LENGTH: usize = 16;
25+
/// Length of the protocol id.
26+
pub const PROTOCOL_ID_LENGTH: usize = 6;
27+
/// Length of the protocol version.
28+
pub const PROTOCOL_VERSION_LENGTH: usize = 2;
2429
/// The length of the static header. (6 byte protocol id, 2 bytes version, 1 byte kind, 12 byte
2530
/// message nonce and a 2 byte authdata-size).
2631
pub const STATIC_HEADER_LENGTH: usize = 23;
@@ -29,10 +34,18 @@ pub const MESSAGE_NONCE_LENGTH: usize = 12;
2934
/// The Id nonce length (in bytes).
3035
pub const ID_NONCE_LENGTH: usize = 16;
3136

32-
/// Protocol ID sent with each message.
33-
const PROTOCOL_ID: &str = "discv5";
34-
/// The version sent with each handshake.
35-
const VERSION: u16 = 0x0001;
37+
/// Protocol ID bytes sent with each message.
38+
#[cfg(not(test))]
39+
pub(crate) static PROTOCOL_ID: OnceCell<[u8; PROTOCOL_ID_LENGTH]> = OnceCell::new();
40+
#[cfg(test)]
41+
pub(crate) static PROTOCOL_ID: OnceCell<[u8; PROTOCOL_ID_LENGTH]> =
42+
OnceCell::with_value(crate::config::DEFAULT_PROTOCOL_ID);
43+
/// The version bytes sent with each handshake.
44+
#[cfg(not(test))]
45+
pub(crate) static VERSION: OnceCell<[u8; PROTOCOL_VERSION_LENGTH]> = OnceCell::new();
46+
#[cfg(test)]
47+
pub(crate) static VERSION: OnceCell<[u8; PROTOCOL_VERSION_LENGTH]> =
48+
OnceCell::with_value(crate::config::DEFAULT_PROTOCOL_VERSION);
3649

3750
pub(crate) const MAX_PACKET_SIZE: usize = 1280;
3851
// The smallest packet must be at least this large
@@ -95,8 +108,8 @@ impl PacketHeader {
95108
pub fn encode(&self) -> Vec<u8> {
96109
let auth_data = self.kind.encode();
97110
let mut buf = Vec::with_capacity(auth_data.len() + STATIC_HEADER_LENGTH);
98-
buf.extend_from_slice(PROTOCOL_ID.as_bytes());
99-
buf.extend_from_slice(&VERSION.to_be_bytes());
111+
buf.extend_from_slice(PROTOCOL_ID.wait());
112+
buf.extend_from_slice(VERSION.wait());
100113
let kind: u8 = (&self.kind).into();
101114
buf.extend_from_slice(&kind.to_be_bytes());
102115
buf.extend_from_slice(&self.message_nonce);
@@ -440,17 +453,16 @@ impl Packet {
440453
}
441454

442455
// Check the protocol id
443-
if &static_header[..6] != PROTOCOL_ID.as_bytes() {
456+
if &static_header[..PROTOCOL_ID_LENGTH] != PROTOCOL_ID.wait() {
444457
return Err(PacketError::HeaderDecryptionFailed);
445458
}
446459

447460
// Check the version matches
448-
let version = u16::from_be_bytes(
449-
static_header[6..8]
450-
.try_into()
451-
.expect("Must be correct size"),
452-
);
453-
if version != VERSION {
461+
let version_bytes =
462+
&static_header[PROTOCOL_ID_LENGTH..PROTOCOL_ID_LENGTH + PROTOCOL_VERSION_LENGTH];
463+
if version_bytes != VERSION.wait() {
464+
let version =
465+
u16::from_be_bytes(version_bytes.try_into().expect("Must be correct size"));
454466
return Err(PacketError::InvalidVersion(version));
455467
}
456468

0 commit comments

Comments
 (0)