Skip to content

Commit 6ecc444

Browse files
committed
LBP: Return Option<Shard> instead of Shard
This was already documented as such, but due to an oversight the code disagreed with the documentation. The approach from the documentation is better, because the previously implemented approach prevented deduplication in `Plan` from working correctly.
1 parent cf0b1cd commit 6ecc444

File tree

7 files changed

+252
-138
lines changed

7 files changed

+252
-138
lines changed

examples/custom_load_balancing_policy.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,20 @@ struct CustomLoadBalancingPolicy {
1818
fav_datacenter_name: String,
1919
}
2020

21-
fn with_random_shard(node: NodeRef) -> (NodeRef, Shard) {
21+
fn with_random_shard(node: NodeRef) -> (NodeRef, Option<Shard>) {
2222
let nr_shards = node
2323
.sharder()
2424
.map(|sharder| sharder.nr_shards.get())
2525
.unwrap_or(1);
26-
(node, thread_rng().gen_range(0..nr_shards) as Shard)
26+
(node, Some(thread_rng().gen_range(0..nr_shards) as Shard))
2727
}
2828

2929
impl LoadBalancingPolicy for CustomLoadBalancingPolicy {
3030
fn pick<'a>(
3131
&'a self,
3232
_info: &'a RoutingInfo,
3333
cluster: &'a ClusterData,
34-
) -> Option<(NodeRef<'a>, Shard)> {
34+
) -> Option<(NodeRef<'a>, Option<Shard>)> {
3535
self.fallback(_info, cluster).next()
3636
}
3737

scylla/src/transport/load_balancing/default.rs

Lines changed: 161 additions & 113 deletions
Large diffs are not rendered by default.

scylla/src/transport/load_balancing/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ pub struct RoutingInfo<'a> {
3939
///
4040
/// It is computed on-demand, only if querying the most preferred node fails
4141
/// (or when speculative execution is triggered).
42-
pub type FallbackPlan<'a> = Box<dyn Iterator<Item = (NodeRef<'a>, Shard)> + Send + Sync + 'a>;
42+
pub type FallbackPlan<'a> =
43+
Box<dyn Iterator<Item = (NodeRef<'a>, Option<Shard>)> + Send + Sync + 'a>;
4344

4445
/// Policy that decides which nodes and shards to contact for each query.
4546
///
@@ -67,7 +68,7 @@ pub trait LoadBalancingPolicy: Send + Sync + std::fmt::Debug {
6768
&'a self,
6869
query: &'a RoutingInfo,
6970
cluster: &'a ClusterData,
70-
) -> Option<(NodeRef<'a>, Shard)>;
71+
) -> Option<(NodeRef<'a>, Option<Shard>)>;
7172

7273
/// Returns all contact-appropriate nodes for a given query.
7374
fn fallback<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData)

scylla/src/transport/load_balancing/plan.rs

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use rand::{thread_rng, Rng};
12
use tracing::error;
23

34
use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
@@ -6,20 +7,65 @@ use crate::{routing::Shard, transport::ClusterData};
67
enum PlanState<'a> {
78
Created,
89
PickedNone, // This always means an abnormal situation: it means that no nodes satisfied locality/node filter requirements.
9-
Picked((NodeRef<'a>, Shard)),
10+
Picked((NodeRef<'a>, Option<Shard>)),
1011
Fallback {
1112
iter: FallbackPlan<'a>,
12-
node_to_filter_out: (NodeRef<'a>, Shard),
13+
target_to_filter_out: (NodeRef<'a>, Option<Shard>),
1314
},
1415
}
1516

16-
/// The list of nodes constituting the query plan.
17+
/// The list of targets constituting the query plan. A target here is a pair `(NodeRef<'a>, Shard)`.
1718
///
18-
/// The plan is partly lazily computed, with the first node computed
19-
/// eagerly in the first place and the remaining nodes computed on-demand
19+
/// The plan is partly lazily computed, with the first target computed
20+
/// eagerly in the first place and the remaining targets computed on-demand
2021
/// (all at once).
2122
/// This significantly reduces the allocation overhead on "the happy path"
22-
/// (when the first node successfully handles the request),
23+
/// (when the first target successfully handles the request).
24+
///
25+
/// `Plan` implements `Iterator<Item=(NodeRef<'a>, Shard)>` but LoadBalancingPolicy
26+
/// returns `Option<Shard>` instead of `Shard` both in `pick` and in `fallback`.
27+
/// `Plan` handles the `None` case by using a random shard for the given node.
28+
/// There is currently no way to configure RNG used by `Plan`.
29+
/// If you don't want `Plan` to randomize shards, or you want to control the RNG,
30+
/// use a custom LBP that always returns non-`None` shards.
31+
/// Example of an LBP that always uses shard 0, preventing `Plan` from using random numbers:
32+
///
33+
/// ```
34+
/// # use std::sync::Arc;
35+
/// # use scylla::load_balancing::LoadBalancingPolicy;
36+
/// # use scylla::load_balancing::RoutingInfo;
37+
/// # use scylla::transport::ClusterData;
38+
/// # use scylla::transport::NodeRef;
39+
/// # use scylla::routing::Shard;
40+
/// # use scylla::load_balancing::FallbackPlan;
41+
///
42+
/// #[derive(Debug)]
43+
/// struct NonRandomLBP {
44+
/// inner: Arc<dyn LoadBalancingPolicy>,
45+
/// }
46+
/// impl LoadBalancingPolicy for NonRandomLBP {
47+
/// fn pick<'a>(
48+
/// &'a self,
49+
/// info: &'a RoutingInfo,
50+
/// cluster: &'a ClusterData,
51+
/// ) -> Option<(NodeRef<'a>, Option<Shard>)> {
52+
/// self.inner
53+
/// .pick(info, cluster)
54+
/// .map(|(node, shard)| (node, shard.or(Some(0))))
55+
/// }
56+
///
57+
/// fn fallback<'a>(&'a self, info: &'a RoutingInfo, cluster: &'a ClusterData) -> FallbackPlan<'a> {
58+
/// Box::new(self.inner
59+
/// .fallback(info, cluster)
60+
/// .map(|(node, shard)| (node, shard.or(Some(0)))))
61+
/// }
62+
///
63+
/// fn name(&self) -> String {
64+
/// "NonRandomLBP".to_string()
65+
/// }
66+
/// }
67+
/// ```
68+
2369
pub struct Plan<'a> {
2470
policy: &'a dyn LoadBalancingPolicy,
2571
routing_info: &'a RoutingInfo<'a>,
@@ -41,6 +87,21 @@ impl<'a> Plan<'a> {
4187
state: PlanState::Created,
4288
}
4389
}
90+
91+
fn with_random_shard_if_unknown(
92+
(node, shard): (NodeRef<'_>, Option<Shard>),
93+
) -> (NodeRef<'_>, Shard) {
94+
(
95+
node,
96+
shard.unwrap_or_else(|| {
97+
let nr_shards = node
98+
.sharder()
99+
.map(|sharder| sharder.nr_shards.get())
100+
.unwrap_or(1);
101+
thread_rng().gen_range(0..nr_shards).into()
102+
}),
103+
)
104+
}
44105
}
45106

46107
impl<'a> Iterator for Plan<'a> {
@@ -52,7 +113,7 @@ impl<'a> Iterator for Plan<'a> {
52113
let picked = self.policy.pick(self.routing_info, self.cluster);
53114
if let Some(picked) = picked {
54115
self.state = PlanState::Picked(picked);
55-
Some(picked)
116+
Some(Self::with_random_shard_if_unknown(picked))
56117
} else {
57118
// `pick()` returned None, which semantically means that a first node cannot be computed _cheaply_.
58119
// This, however, does not imply that fallback would return an empty plan, too.
@@ -64,9 +125,9 @@ impl<'a> Iterator for Plan<'a> {
64125
if let Some(node) = first_fallback_node {
65126
self.state = PlanState::Fallback {
66127
iter,
67-
node_to_filter_out: node,
128+
target_to_filter_out: node,
68129
};
69-
Some(node)
130+
Some(Self::with_random_shard_if_unknown(node))
70131
} else {
71132
error!("Load balancing policy returned an empty plan! The query cannot be executed. Routing info: {:?}", self.routing_info);
72133
self.state = PlanState::PickedNone;
@@ -77,20 +138,20 @@ impl<'a> Iterator for Plan<'a> {
77138
PlanState::Picked(node) => {
78139
self.state = PlanState::Fallback {
79140
iter: self.policy.fallback(self.routing_info, self.cluster),
80-
node_to_filter_out: *node,
141+
target_to_filter_out: *node,
81142
};
82143

83144
self.next()
84145
}
85146
PlanState::Fallback {
86147
iter,
87-
node_to_filter_out,
148+
target_to_filter_out: node_to_filter_out,
88149
} => {
89150
for node in iter {
90151
if node == *node_to_filter_out {
91152
continue;
92153
} else {
93-
return Some(node);
154+
return Some(Self::with_random_shard_if_unknown(node));
94155
}
95156
}
96157

@@ -135,7 +196,7 @@ mod tests {
135196
&'a self,
136197
_query: &'a RoutingInfo,
137198
_cluster: &'a ClusterData,
138-
) -> Option<(NodeRef<'a>, Shard)> {
199+
) -> Option<(NodeRef<'a>, Option<Shard>)> {
139200
None
140201
}
141202

@@ -147,7 +208,7 @@ mod tests {
147208
Box::new(
148209
self.expected_nodes
149210
.iter()
150-
.map(|(node_ref, shard)| (node_ref, *shard)),
211+
.map(|(node_ref, shard)| (node_ref, Some(*shard))),
151212
)
152213
}
153214

scylla/tests/integration/consistency.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ impl LoadBalancingPolicy for RoutingInfoReportingWrapper {
379379
&'a self,
380380
query: &'a RoutingInfo,
381381
cluster: &'a scylla::transport::ClusterData,
382-
) -> Option<(NodeRef<'a>, Shard)> {
382+
) -> Option<(NodeRef<'a>, Option<Shard>)> {
383383
self.routing_info_tx
384384
.send(OwnedRoutingInfo::from(query.clone()))
385385
.unwrap();

scylla/tests/integration/execution_profiles.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,13 @@ impl<const NODE: u8> LoadBalancingPolicy for BoundToPredefinedNodePolicy<NODE> {
5151
&'a self,
5252
_info: &'a RoutingInfo,
5353
cluster: &'a ClusterData,
54-
) -> Option<(NodeRef<'a>, Shard)> {
54+
) -> Option<(NodeRef<'a>, Option<Shard>)> {
5555
self.report_node(Report::LoadBalancing);
56-
cluster.get_nodes_info().iter().next().map(|node| (node, 0))
56+
cluster
57+
.get_nodes_info()
58+
.iter()
59+
.next()
60+
.map(|node| (node, None))
5761
}
5862

5963
fn fallback<'a>(

scylla/tests/integration/utils.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ pub(crate) fn setup_tracing() {
1919
.try_init();
2020
}
2121

22-
fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Shard) {
22+
fn with_pseudorandom_shard(node: NodeRef) -> (NodeRef, Option<Shard>) {
2323
let nr_shards = node
2424
.sharder()
2525
.map(|sharder| sharder.nr_shards.get())
2626
.unwrap_or(1);
27-
(node, ((nr_shards - 1) % 42) as Shard)
27+
(node, Some(((nr_shards - 1) % 42) as Shard))
2828
}
2929

3030
#[derive(Debug)]
@@ -34,7 +34,7 @@ impl LoadBalancingPolicy for FixedOrderLoadBalancer {
3434
&'a self,
3535
_info: &'a scylla::load_balancing::RoutingInfo,
3636
cluster: &'a scylla::transport::ClusterData,
37-
) -> Option<(NodeRef<'a>, Shard)> {
37+
) -> Option<(NodeRef<'a>, Option<Shard>)> {
3838
cluster
3939
.get_nodes_info()
4040
.iter()

0 commit comments

Comments
 (0)