|
1 | | -use rand::Rng; |
2 | | -use std::collections::HashMap; |
3 | | -use std::convert::TryFrom; |
4 | | -use std::num::NonZeroU16; |
5 | | -use thiserror::Error; |
| 1 | +mod sharding; |
| 2 | + |
| 3 | +pub(crate) use sharding::ShardInfo; |
| 4 | +pub use sharding::{Shard, ShardCount, Sharder, ShardingError}; |
6 | 5 |
|
7 | 6 | #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)] |
8 | 7 |
|
@@ -42,191 +41,3 @@ impl Token { |
42 | 41 | self.value |
43 | 42 | } |
44 | 43 | } |
45 | | - |
/// Shard identifier, as returned by [`Sharder::shard_of`] and
/// [`Sharder::shard_of_source_port`].
pub type Shard = u32;
/// Number of shards; the `NonZeroU16` representation rules out a zero count.
pub type ShardCount = NonZeroU16;
48 | | - |
/// Sharding parameters parsed from a connection's options map
/// (see the `TryFrom<&HashMap<String, Vec<String>>>` impl).
#[derive(PartialEq, Eq, Clone, Debug)]
pub(crate) struct ShardInfo {
    /// Value of the `SCYLLA_SHARD` option.
    pub(crate) shard: u16,
    /// Value of the `SCYLLA_NR_SHARDS` option; never zero.
    pub(crate) nr_shards: ShardCount,
    /// Value of the `SCYLLA_SHARDING_IGNORE_MSB` option: how far the biased
    /// token is shifted left before it is mapped to a shard.
    pub(crate) msb_ignore: u8,
}
55 | | - |
/// Maps tokens and source ports to shards for a given shard count.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct Sharder {
    /// Total number of shards that tokens and ports are distributed over.
    pub nr_shards: ShardCount,
    /// Number of bits the biased token is shifted left by in `shard_of`,
    /// i.e. how many of its most significant bits are discarded.
    pub msb_ignore: u8,
}
61 | | - |
62 | | -impl std::str::FromStr for Token { |
63 | | - type Err = std::num::ParseIntError; |
64 | | - fn from_str(s: &str) -> Result<Token, std::num::ParseIntError> { |
65 | | - Ok(Token { value: s.parse()? }) |
66 | | - } |
67 | | -} |
68 | | - |
69 | | -impl ShardInfo { |
70 | | - pub(crate) fn new(shard: u16, nr_shards: ShardCount, msb_ignore: u8) -> Self { |
71 | | - ShardInfo { |
72 | | - shard, |
73 | | - nr_shards, |
74 | | - msb_ignore, |
75 | | - } |
76 | | - } |
77 | | - |
78 | | - pub(crate) fn get_sharder(&self) -> Sharder { |
79 | | - Sharder::new(self.nr_shards, self.msb_ignore) |
80 | | - } |
81 | | -} |
82 | | - |
83 | | -impl Sharder { |
84 | | - pub fn new(nr_shards: ShardCount, msb_ignore: u8) -> Self { |
85 | | - Sharder { |
86 | | - nr_shards, |
87 | | - msb_ignore, |
88 | | - } |
89 | | - } |
90 | | - |
91 | | - pub fn shard_of(&self, token: Token) -> Shard { |
92 | | - let mut biased_token = (token.value as u64).wrapping_add(1u64 << 63); |
93 | | - biased_token <<= self.msb_ignore; |
94 | | - (((biased_token as u128) * (self.nr_shards.get() as u128)) >> 64) as Shard |
95 | | - } |
96 | | - |
97 | | - /// If we connect to Scylla using Scylla's shard aware port, then Scylla assigns a shard to the |
98 | | - /// connection based on the source port. This calculates the assigned shard. |
99 | | - pub fn shard_of_source_port(&self, source_port: u16) -> Shard { |
100 | | - (source_port % self.nr_shards.get()) as Shard |
101 | | - } |
102 | | - |
103 | | - /// Randomly choose a source port `p` such that `shard == shard_of_source_port(p)`. |
104 | | - pub fn draw_source_port_for_shard(&self, shard: Shard) -> u16 { |
105 | | - assert!(shard < self.nr_shards.get() as u32); |
106 | | - rand::thread_rng() |
107 | | - .gen_range((49152 + self.nr_shards.get() - 1)..(65535 - self.nr_shards.get() + 1)) |
108 | | - / self.nr_shards.get() |
109 | | - * self.nr_shards.get() |
110 | | - + shard as u16 |
111 | | - } |
112 | | - |
113 | | - /// Returns iterator over source ports `p` such that `shard == shard_of_source_port(p)`. |
114 | | - /// Starts at a random port and goes forward by `nr_shards`. After reaching maximum wraps back around. |
115 | | - /// Stops once all possible ports have been returned |
116 | | - pub fn iter_source_ports_for_shard(&self, shard: Shard) -> impl Iterator<Item = u16> { |
117 | | - assert!(shard < self.nr_shards.get() as u32); |
118 | | - |
119 | | - // Randomly choose a port to start at |
120 | | - let starting_port = self.draw_source_port_for_shard(shard); |
121 | | - |
122 | | - // Choose smallest available port number to begin at after wrapping |
123 | | - // apply the formula from draw_source_port_for_shard for lowest possible gen_range result |
124 | | - let first_valid_port = (49152 + self.nr_shards.get() - 1) / self.nr_shards.get() |
125 | | - * self.nr_shards.get() |
126 | | - + shard as u16; |
127 | | - |
128 | | - let before_wrap = (starting_port..=65535).step_by(self.nr_shards.get().into()); |
129 | | - let after_wrap = (first_valid_port..starting_port).step_by(self.nr_shards.get().into()); |
130 | | - |
131 | | - before_wrap.chain(after_wrap) |
132 | | - } |
133 | | -} |
134 | | - |
/// Errors produced while building a `ShardInfo` from a connection's options map.
#[derive(Clone, Error, Debug)]
pub enum ShardingError {
    /// At least one of the `SCYLLA_SHARD`, `SCYLLA_NR_SHARDS` or
    /// `SCYLLA_SHARDING_IGNORE_MSB` keys is absent from the options map.
    #[error("ShardInfo parameters missing")]
    MissingShardInfoParameter,
    /// All keys are present, but at least one of them has an empty value list.
    #[error("ShardInfo parameters missing after unwrapping")]
    MissingUnwrapedShardInfoParameter,
    /// The shard count parsed successfully but was zero.
    #[error("ShardInfo contains an invalid number of shards (0)")]
    ZeroShards,
    /// One of the option values could not be parsed as an integer.
    #[error("ParseIntError encountered while getting ShardInfo")]
    ParseIntError(#[from] std::num::ParseIntError),
}
146 | | - |
147 | | -impl<'a> TryFrom<&'a HashMap<String, Vec<String>>> for ShardInfo { |
148 | | - type Error = ShardingError; |
149 | | - fn try_from(options: &'a HashMap<String, Vec<String>>) -> Result<Self, Self::Error> { |
150 | | - let shard_entry = options.get("SCYLLA_SHARD"); |
151 | | - let nr_shards_entry = options.get("SCYLLA_NR_SHARDS"); |
152 | | - let msb_ignore_entry = options.get("SCYLLA_SHARDING_IGNORE_MSB"); |
153 | | - if shard_entry.is_none() || nr_shards_entry.is_none() || msb_ignore_entry.is_none() { |
154 | | - return Err(ShardingError::MissingShardInfoParameter); |
155 | | - } |
156 | | - if shard_entry.unwrap().is_empty() |
157 | | - || nr_shards_entry.unwrap().is_empty() |
158 | | - || msb_ignore_entry.unwrap().is_empty() |
159 | | - { |
160 | | - return Err(ShardingError::MissingUnwrapedShardInfoParameter); |
161 | | - } |
162 | | - let shard = shard_entry.unwrap().first().unwrap().parse::<u16>()?; |
163 | | - let nr_shards = nr_shards_entry.unwrap().first().unwrap().parse::<u16>()?; |
164 | | - let nr_shards = ShardCount::new(nr_shards).ok_or(ShardingError::ZeroShards)?; |
165 | | - let msb_ignore = msb_ignore_entry.unwrap().first().unwrap().parse::<u8>()?; |
166 | | - Ok(ShardInfo::new(shard, nr_shards, msb_ignore)) |
167 | | - } |
168 | | -} |
169 | | - |
#[cfg(test)]
mod tests {
    use crate::test_utils::setup_tracing;

    use super::Token;
    use super::{ShardCount, Sharder};
    use std::collections::HashSet;

    /// Known-answer check for `Sharder::shard_of`.
    #[test]
    fn test_shard_of() {
        setup_tracing();
        /* Test values taken from the gocql driver. */
        let sharder = Sharder::new(ShardCount::new(4).unwrap(), 12);

        let low_token = Token {
            value: -9219783007514621794,
        };
        let high_token = Token {
            value: 9222582454147032830,
        };

        assert_eq!(sharder.shard_of(low_token), 3);
        assert_eq!(sharder.shard_of(high_token), 3);
    }

    /// For every shard, the port iterator must yield each matching port
    /// exactly once and nothing else.
    #[test]
    fn test_iter_source_ports_for_shard() {
        setup_tracing();
        let nr_shards: u16 = 4;
        let max_port_num: u16 = 65535;
        let min_port_num: u16 = (49152 + nr_shards - 1) / nr_shards * nr_shards;

        let sharder = Sharder::new(ShardCount::new(nr_shards).unwrap(), 12);

        // Test for each shard
        for shard in 0..nr_shards {
            // Find lowest port for this shard
            let lowest_port = (min_port_num..)
                .find(|port| port % nr_shards == shard)
                .unwrap();

            // Find total number of ports the iterator should return
            let possible_ports_number = usize::from((max_port_num - lowest_port) / nr_shards + 1);

            let mut returned_ports: HashSet<u16> = HashSet::new();
            for port in sharder.iter_source_ports_for_shard(shard.into()) {
                // No port occurs two times (`insert` returns false on a repeat)
                assert!(returned_ports.insert(port));
                // Each port maps to this shard
                assert!(port % nr_shards == shard);
            }

            // Numbers of ports returned matches the expected value
            assert_eq!(returned_ports.len(), possible_ports_number);
        }
    }
}
0 commit comments