Skip to content

Commit 13a5948

Browse files
committed
extract sharding.rs
1 parent ab0a6f2 commit 13a5948

File tree

2 files changed

+199
-193
lines changed

2 files changed

+199
-193
lines changed

scylla/src/routing/mod.rs

Lines changed: 4 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
use rand::Rng;
2-
use std::collections::HashMap;
3-
use std::convert::TryFrom;
4-
use std::num::NonZeroU16;
5-
use thiserror::Error;
1+
mod sharding;
2+
3+
pub(crate) use sharding::ShardInfo;
4+
pub use sharding::{Shard, ShardCount, Sharder, ShardingError};
65

76
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
87

@@ -42,191 +41,3 @@ impl Token {
4241
self.value
4342
}
4443
}
45-
46-
pub type Shard = u32;
47-
pub type ShardCount = NonZeroU16;
48-
49-
#[derive(PartialEq, Eq, Clone, Debug)]
50-
pub(crate) struct ShardInfo {
51-
pub(crate) shard: u16,
52-
pub(crate) nr_shards: ShardCount,
53-
pub(crate) msb_ignore: u8,
54-
}
55-
56-
#[derive(PartialEq, Eq, Clone, Debug)]
57-
pub struct Sharder {
58-
pub nr_shards: ShardCount,
59-
pub msb_ignore: u8,
60-
}
61-
62-
impl std::str::FromStr for Token {
63-
type Err = std::num::ParseIntError;
64-
fn from_str(s: &str) -> Result<Token, std::num::ParseIntError> {
65-
Ok(Token { value: s.parse()? })
66-
}
67-
}
68-
69-
impl ShardInfo {
70-
pub(crate) fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self {
71-
ShardInfo {
72-
shard,
73-
nr_shards,
74-
msb_ignore,
75-
}
76-
}
77-
78-
pub(crate) fn get_sharder(&self) -> Sharder {
79-
Sharder::new(self.nr_shards, self.msb_ignore)
80-
}
81-
}
82-
83-
impl Sharder {
84-
pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self {
85-
Sharder {
86-
nr_shards,
87-
msb_ignore,
88-
}
89-
}
90-
91-
pub fn shard_of(&self, token: Token) -> Shard {
92-
let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63);
93-
biased_token <<= self.msb_ignore;
94-
(((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard
95-
}
96-
97-
/// If we connect to Scylla using Scylla's shard aware port, then Scylla assigns a shard to the
98-
/// connection based on the source port. This calculates the assigned shard.
99-
pub fn shard_of_source_port(&self, source_port: u16) -> Shard {
100-
(source_port % self.nr_shards.get()) as Shard
101-
}
102-
103-
/// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`.
104-
pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 {
105-
assert!(shard < self.nr_shards.get() as u32);
106-
rand::thread_rng()
107-
.gen_range((49152 + self.nr_shards.get() - 1)..(65535 - self.nr_shards.get() + 1))
108-
/ self.nr_shards.get()
109-
* self.nr_shards.get()
110-
+ shard as u16
111-
}
112-
113-
/// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`.
114-
/// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around.
115-
/// Stops once all possible ports have been returned
116-
pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> {
117-
assert!(shard < self.nr_shards.get() as u32);
118-
119-
// Randomly choose a port to start at
120-
let starting_port = self.draw_source_port_for_shard(shard);
121-
122-
// Choose smallest available port number to begin at after wrapping
123-
// apply the formula from draw_source_port_for_shard for lowest possible gen_range result
124-
let first_valid_port = (49152 + self.nr_shards.get() - 1) / self.nr_shards.get()
125-
* self.nr_shards.get()
126-
+ shard as u16;
127-
128-
let before_wrap = (starting_port..=65535).step_by(self.nr_shards.get().into());
129-
let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into());
130-
131-
before_wrap.chain(after_wrap)
132-
}
133-
}
134-
135-
#[derive(Clone, Error, Debug)]
136-
pub enum ShardingError {
137-
#[error("ShardInfo parameters missing")]
138-
MissingShardInfoParameter,
139-
#[error("ShardInfo parameters missing after unwrapping")]
140-
MissingUnwrapedShardInfoParameter,
141-
#[error("ShardInfo contains an invalid number of shards (0)")]
142-
ZeroShards,
143-
#[error("ParseIntError encountered while getting ShardInfo")]
144-
ParseIntError(#[from] std::num::ParseIntError),
145-
}
146-
147-
impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo {
148-
type Error = ShardingError;
149-
fn try_from(options: &'a HashMap<String, Vec<String>>) -> Result<Self, Self::Error> {
150-
let shard_entry = options.get("SCYLLA_SHARD");
151-
let nr_shards_entry = options.get("SCYLLA_NR_SHARDS");
152-
let msb_ignore_entry = options.get("SCYLLA_SHARDING_IGNORE_MSB");
153-
if shard_entry.is_none() || nr_shards_entry.is_none() || msb_ignore_entry.is_none() {
154-
return Err(ShardingError::MissingShardInfoParameter);
155-
}
156-
if shard_entry.unwrap().is_empty()
157-
|| nr_shards_entry.unwrap().is_empty()
158-
|| msb_ignore_entry.unwrap().is_empty()
159-
{
160-
return Err(ShardingError::MissingUnwrapedShardInfoParameter);
161-
}
162-
let shard = shard_entry.unwrap().first().unwrap().parse::<u16>()?;
163-
let nr_shards = nr_shards_entry.unwrap().first().unwrap().parse::<u16>()?;
164-
let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?;
165-
let msb_ignore = msb_ignore_entry.unwrap().first().unwrap().parse::<u8>()?;
166-
Ok(ShardInfo::new(shard, nr_shards, msb_ignore))
167-
}
168-
}
169-
170-
#[cfg(test)]
171-
mod tests {
172-
use crate::test_utils::setup_tracing;
173-
174-
use super::Token;
175-
use super::{ShardCount, Sharder};
176-
use std::collections::HashSet;
177-
178-
#[test]
179-
fn test_shard_of() {
180-
setup_tracing();
181-
/* Test values taken from the gocql driver. */
182-
let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);
183-
assert_eq!(
184-
sharder.shard_of(Token {
185-
value: -9219783007514621794
186-
}),
187-
3
188-
);
189-
assert_eq!(
190-
sharder.shard_of(Token {
191-
value: 9222582454147032830
192-
}),
193-
3
194-
);
195-
}
196-
197-
#[test]
198-
fn test_iter_source_ports_for_shard() {
199-
setup_tracing();
200-
let nr_shards = 4;
201-
let max_port_num = 65535;
202-
let min_port_num = (49152 + nr_shards - 1) / nr_shards * nr_shards;
203-
204-
let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);
205-
206-
// Test for each shard
207-
for shard in 0..nr_shards {
208-
// Find lowest port for this shard
209-
let mut lowest_port = min_port_num;
210-
while lowest_port % nr_shards != shard {
211-
lowest_port += 1;
212-
}
213-
214-
// Find total number of ports the iterator should return
215-
let possible_ports_number: usize =
216-
((max_port_num - lowest_port) / nr_shards + 1).into();
217-
218-
let port_iter = sharder.iter_source_ports_for_shard(shard.into());
219-
220-
let mut returned_ports: HashSet<u16> = HashSet::new();
221-
for port in port_iter {
222-
assert!(!returned_ports.contains(&port)); // No port occurs two times
223-
assert!(port % nr_shards == shard); // Each port maps to this shard
224-
225-
returned_ports.insert(port);
226-
}
227-
228-
// Numbers of ports returned matches the expected value
229-
assert_eq!(returned_ports.len(), possible_ports_number);
230-
}
231-
}
232-
}

0 commit comments

Comments
 (0)