Skip to content

Commit 7e47f67

Browse files
committed
add enforce_target_node on batch
1 parent 20c32b8 commit 7e47f67

File tree

5 files changed

+132
-5
lines changed

5 files changed

+132
-5
lines changed

scylla/src/statement/batch.rs

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use std::sync::Arc;
33
use crate::history::HistoryListener;
44
use crate::retry_policy::RetryPolicy;
55
use crate::statement::{prepared_statement::PreparedStatement, query::Query};
6-
use crate::transport::execution_profile::ExecutionProfileHandle;
6+
use crate::transport::{execution_profile::ExecutionProfileHandle, Node};
7+
use crate::Session;
78

89
use super::StatementConfig;
910
pub use super::{Consistency, SerialConsistency};
@@ -144,6 +145,82 @@ impl Batch {
144145
pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> {
145146
self.config.execution_profile_handle.as_ref()
146147
}
148+
149+
/// Associates the batch with a new execution profile that will have a load balancing policy
150+
/// that will enforce the use of the provided [`Node`] to the extent possible.
151+
///
152+
/// This should typically be used in conjunction with [`Session::shard_for_statement`], where
153+
/// you would constitute a batch by assigning to the same batch all the statements that would be executed in
154+
/// the same shard.
155+
///
156+
/// Since it is not guaranteed that subsequent calls to the load balancer would re-assign the statement
157+
/// to the same node, you should use this method to enforce the use of the original node that was envisioned by
158+
/// `shard_for_statement` for the batch:
159+
///
160+
/// ```rust
161+
/// # use scylla::Session;
162+
/// # use std::error::Error;
163+
/// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
164+
/// use scylla::{
165+
/// batch::Batch,
166+
/// frame::value::{SerializedValues, ValueList},
167+
/// };
168+
///
169+
/// let prepared_statement = session
170+
/// .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
171+
/// .await?;
172+
///
173+
/// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned();
174+
/// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
175+
///
176+
/// // Send that to a task that will handle statements targeted to the same shard
177+
///
178+
/// // On that task:
179+
/// // Constitute a batch with all the statements that would be executed in the same shard
180+
///
181+
/// let mut batch: Batch = Default::default();
182+
/// if let Some((node, _shard_idx)) = shard {
183+
/// batch.enforce_target_node(&node, &session);
184+
/// }
185+
/// let mut batch_values = Vec::new();
186+
///
187+
/// // As the task handling statements targeted to this shard receives them,
188+
/// // it appends them to the batch
189+
/// batch.append_statement(prepared_statement);
190+
/// batch_values.push(serialized_values);
191+
///
192+
/// // Run the batch
193+
/// session.batch(&batch, batch_values).await?;
194+
/// # Ok(())
195+
/// # }
196+
/// ```
197+
///
198+
///
199+
/// If the target node is not available anymore at the time of executing the statement, it will fallback to the
200+
/// original load balancing policy:
201+
/// - Either that currently set on the [`Batch`], if any
202+
/// - Or that of the [`Session`] if there isn't one on the `Batch`
203+
pub fn enforce_target_node<'a>(
204+
&mut self,
205+
node: &Arc<Node>,
206+
base_execution_profile_from_session: &Session,
207+
) {
208+
let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
209+
base_execution_profile_from_session.get_default_execution_profile_handle()
210+
});
211+
self.set_execution_profile_handle(Some(
212+
execution_profile_handle
213+
.pointee_to_builder()
214+
.load_balancing_policy(Arc::new(
215+
crate::load_balancing::EnforceTargetNodePolicy::new(
216+
node,
217+
execution_profile_handle.load_balancing_policy(),
218+
),
219+
))
220+
.build()
221+
.into_handle(),
222+
))
223+
}
147224
}
148225

149226
impl Default for Batch {

scylla/src/transport/execution_profile.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,4 +485,8 @@ impl ExecutionProfileHandle {
485485
pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
486486
self.0 .0.store(profile.0)
487487
}
488+
489+
pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
490+
self.0 .0.load().load_balancing_policy.clone()
491+
}
488492
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
2+
use crate::transport::{cluster::ClusterData, Node};
3+
use std::sync::Arc;
4+
5+
#[derive(Debug)]
6+
pub struct EnforceTargetNodePolicy {
7+
target_node: uuid::Uuid,
8+
fallback: Arc<dyn LoadBalancingPolicy>,
9+
}
10+
11+
impl EnforceTargetNodePolicy {
12+
pub fn new(target_node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) -> Self {
13+
Self {
14+
target_node: target_node.host_id,
15+
fallback,
16+
}
17+
}
18+
}
19+
impl LoadBalancingPolicy for EnforceTargetNodePolicy {
20+
fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option<NodeRef<'a>> {
21+
cluster
22+
.known_peers
23+
.get(&self.target_node)
24+
.or_else(|| self.fallback.pick(query, cluster))
25+
}
26+
27+
fn fallback<'a>(
28+
&'a self,
29+
query: &'a RoutingInfo,
30+
cluster: &'a ClusterData,
31+
) -> FallbackPlan<'a> {
32+
self.fallback.fallback(query, cluster)
33+
}
34+
35+
fn name(&self) -> String {
36+
format!(
37+
"Enforce target node Load balancing policy - Node: {} - fallback: {}",
38+
self.target_node,
39+
self.fallback.name()
40+
)
41+
}
42+
}

scylla/src/transport/load_balancing/mod.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ use scylla_cql::{errors::QueryError, frame::types};
99
use std::time::Duration;
1010

1111
mod default;
12+
mod enforce_node;
1213
mod plan;
13-
pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
1414
pub use plan::Plan;
15+
pub use {
16+
default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder},
17+
enforce_node::EnforceTargetNodePolicy,
18+
};
1519

1620
/// Represents info about statement that can be used by load balancing policies.
1721
#[derive(Default, Clone, Debug)]

scylla/src/transport/session.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,11 +1812,11 @@ impl Session {
18121812
.map(|partition_key| prepared.get_partitioner_name().hash(&partition_key)))
18131813
}
18141814

1815-
/// Get the first node/shard that the load balancer would target if running this query
1815+
/// Get a node/shard that the load balancer would potentially target if running this query
18161816
///
1817-
/// This may help constituting shard-aware batches
1817+
/// This may help constituting shard-aware batches (see [`Batch::enforce_target_node`])
18181818
#[allow(clippy::type_complexity)]
1819-
pub fn first_shard_for_statement(
1819+
pub fn shard_for_statement(
18201820
&self,
18211821
prepared: &PreparedStatement,
18221822
serialized_values: &SerializedValues,

0 commit comments

Comments
 (0)