Skip to content

Commit e6ea580

Browse files
authored
Merge pull request #1326 from muzarski/enforce-request-coordinator
Enforce request coordinator
2 parents 78bbc93 + 9647bf4 commit e6ea580

File tree

13 files changed

+427
-65
lines changed

13 files changed

+427
-65
lines changed

docs/source/load-balancing/load-balancing.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,12 @@ awareness enabled and latency-awareness disabled.
3636

3737
## Configuration
3838

39-
Load balancing policies can be configured via execution profiles. In the code
40-
sample provided, a new execution profile is created using
39+
Load balancing policies can be configured in three different ways (sorted by descending precedence):
40+
1. directly on the `Statement`, `PreparedStatement` or `Batch` (`Statement::set_load_balancing_policy()`)
41+
2. execution profile set on the statement (`Statement::set_execution_profile_handle()`)
42+
3. default execution profile set on the session (`SessionBuilder::default_execution_profile_handle()`)
43+
44+
In the code sample provided, a new execution profile is created using
4145
`ExecutionProfile::builder()`, and the load balancing policy is set to the
4246
`DefaultPolicy` using `.load_balancing_policy(policy)`.
4347

examples/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,10 @@ path = "get_by_name.rs"
9898
name = "value_list"
9999
path = "value_list.rs"
100100

101+
[[example]]
102+
name = "enforce_coordinator"
103+
path = "enforce_coordinator.rs"
104+
101105
[[example]]
102106
name = "custom_load_balancing_policy"
103107
path = "custom_load_balancing_policy.rs"

examples/enforce_coordinator.rs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
//! This example show how to enforce the target the request is sent to.
2+
3+
use std::net::{IpAddr, SocketAddr};
4+
use std::sync::Arc;
5+
6+
use anyhow::Result;
7+
use scylla::client::session::Session;
8+
use scylla::client::session_builder::SessionBuilder;
9+
use scylla::cluster::Node;
10+
use scylla::policies::load_balancing::{NodeIdentifier, SingleTargetLoadBalancingPolicy};
11+
use scylla::statement::prepared::PreparedStatement;
12+
13+
/// Executes "SELECT host_id, rpc_address FROM system.local" query with `node` as the enforced target.
14+
/// Checks whether the result matches the expected values (i.e. ones stored in peers metadata).
15+
async fn query_system_local_and_verify(
16+
session: &Session,
17+
node: &Arc<Node>,
18+
query_local: &PreparedStatement,
19+
) {
20+
let (actual_host_id, actual_node_ip) = session
21+
.execute_unpaged(query_local, ())
22+
.await
23+
.unwrap()
24+
.into_rows_result()
25+
.unwrap()
26+
.single_row::<(uuid::Uuid, IpAddr)>()
27+
.unwrap();
28+
29+
println!(
30+
"queried host_id: {}; queried node_ip: {}",
31+
actual_host_id, actual_node_ip
32+
);
33+
assert_eq!(node.host_id, actual_host_id);
34+
assert_eq!(node.address.ip(), actual_node_ip);
35+
}
36+
37+
#[tokio::main]
38+
async fn main() -> Result<()> {
39+
let uri = std::env::var("SCYLLA_URI").unwrap_or_else(|_| "127.0.0.1:9042".to_string());
40+
41+
let session: Session = SessionBuilder::new().known_node(uri).build().await?;
42+
43+
let state = session.get_cluster_state();
44+
let node = state
45+
.get_nodes_info()
46+
.first()
47+
.ok_or_else(|| anyhow::anyhow!("No nodes in metadata!"))?;
48+
49+
let expected_host_id = node.host_id;
50+
let expected_node_ip = node.address.ip();
51+
52+
let mut query_local = session
53+
.prepare("SELECT host_id, rpc_address FROM system.local where key='local'")
54+
.await?;
55+
56+
// Enforce the node using `Arc<Node>`.
57+
{
58+
let node_identifier = NodeIdentifier::Node(Arc::clone(node));
59+
println!("Enforcing target using {:?}...", node_identifier);
60+
query_local.set_load_balancing_policy(Some(SingleTargetLoadBalancingPolicy::new(
61+
node_identifier,
62+
None,
63+
)));
64+
65+
query_system_local_and_verify(&session, node, &query_local).await;
66+
}
67+
68+
// Enforce the node using host_id.
69+
{
70+
let node_identifier = NodeIdentifier::HostId(expected_host_id);
71+
println!("Enforcing target using {:?}...", node_identifier);
72+
query_local.set_load_balancing_policy(Some(SingleTargetLoadBalancingPolicy::new(
73+
node_identifier,
74+
None,
75+
)));
76+
77+
query_system_local_and_verify(&session, node, &query_local).await;
78+
}
79+
80+
// Enforce the node using **untranslated** node address.
81+
{
82+
let node_identifier =
83+
NodeIdentifier::NodeAddress(SocketAddr::new(expected_node_ip, node.address.port()));
84+
println!("Enforcing target using {:?}...", node_identifier);
85+
query_local.set_load_balancing_policy(Some(SingleTargetLoadBalancingPolicy::new(
86+
node_identifier,
87+
None,
88+
)));
89+
90+
query_system_local_and_verify(&session, node, &query_local).await;
91+
}
92+
93+
Ok(())
94+
}

scylla/src/client/pager.rs

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::observability::driver_tracing::RequestSpan;
3434
use crate::observability::history::{self, HistoryListener};
3535
#[cfg(feature = "metrics")]
3636
use crate::observability::metrics::Metrics;
37-
use crate::policies::load_balancing::{self, RoutingInfo};
37+
use crate::policies::load_balancing::{self, LoadBalancingPolicy, RoutingInfo};
3838
use crate::policies::retry::{RequestInfo, RetryDecision, RetrySession};
3939
use crate::response::query_result::ColumnSpecs;
4040
use crate::response::{NonErrorQueryResponse, QueryResponse};
@@ -136,11 +136,11 @@ struct PagerWorker<'a, QueryFunc, SpanCreatorFunc> {
136136
// AsyncFn(Arc<Connection>, Option<Arc<[u8]>>) -> Result<QueryResponse, RequestAttemptError>
137137
page_query: QueryFunc,
138138

139+
load_balancing_policy: Arc<dyn LoadBalancingPolicy>,
139140
statement_info: RoutingInfo<'a>,
140141
query_is_idempotent: bool,
141142
query_consistency: Consistency,
142143
retry_session: Box<dyn RetrySession>,
143-
execution_profile: Arc<ExecutionProfileInner>,
144144
#[cfg(feature = "metrics")]
145145
metrics: Arc<Metrics>,
146146

@@ -162,7 +162,7 @@ where
162162
{
163163
// Contract: this function MUST send at least one item through self.sender
164164
async fn work(mut self, cluster_state: Arc<ClusterState>) -> PageSendAttemptedProof {
165-
let load_balancer = self.execution_profile.load_balancing_policy.clone();
165+
let load_balancer = Arc::clone(&self.load_balancing_policy);
166166
let statement_info = self.statement_info.clone();
167167
let query_plan =
168168
load_balancing::Plan::new(load_balancer.as_ref(), &statement_info, &cluster_state);
@@ -335,8 +335,7 @@ where
335335
let _ = self.metrics.log_query_latency(elapsed.as_millis() as u64);
336336
self.log_attempt_success();
337337
self.log_request_success();
338-
self.execution_profile
339-
.load_balancing_policy
338+
self.load_balancing_policy
340339
.on_request_success(&self.statement_info, elapsed, node);
341340

342341
request_span.record_raw_rows_fields(&rows);
@@ -369,9 +368,12 @@ where
369368
Err(err) => {
370369
#[cfg(feature = "metrics")]
371370
self.metrics.inc_failed_paged_queries();
372-
self.execution_profile
373-
.load_balancing_policy
374-
.on_request_failure(&self.statement_info, elapsed, node, &err);
371+
self.load_balancing_policy.on_request_failure(
372+
&self.statement_info,
373+
elapsed,
374+
node,
375+
&err,
376+
);
375377
Err(err)
376378
}
377379
Ok(NonErrorQueryResponse {
@@ -391,9 +393,12 @@ where
391393
self.metrics.inc_failed_paged_queries();
392394
let err =
393395
RequestAttemptError::UnexpectedResponse(response.response.to_response_kind());
394-
self.execution_profile
395-
.load_balancing_policy
396-
.on_request_failure(&self.statement_info, elapsed, node, &err);
396+
self.load_balancing_policy.on_request_failure(
397+
&self.statement_info,
398+
elapsed,
399+
node,
400+
&err,
401+
);
397402
Err(err)
398403
}
399404
}
@@ -693,6 +698,12 @@ impl QueryPager {
693698
..Default::default()
694699
};
695700

701+
let load_balancing_policy = Arc::clone(
702+
statement
703+
.get_load_balancing_policy()
704+
.unwrap_or(&execution_profile.load_balancing_policy),
705+
);
706+
696707
let retry_session = statement
697708
.get_retry_policy()
698709
.map(|rp| &**rp)
@@ -733,8 +744,8 @@ impl QueryPager {
733744
statement_info: routing_info,
734745
query_is_idempotent: statement.config.is_idempotent,
735746
query_consistency: consistency,
747+
load_balancing_policy,
736748
retry_session,
737-
execution_profile,
738749
#[cfg(feature = "metrics")]
739750
metrics,
740751
paging_state: PagingState::start(),
@@ -769,6 +780,13 @@ impl QueryPager {
769780

770781
let page_size = config.prepared.get_validated_page_size();
771782

783+
let load_balancing_policy = Arc::clone(
784+
config
785+
.prepared
786+
.get_load_balancing_policy()
787+
.unwrap_or(&config.execution_profile.load_balancing_policy),
788+
);
789+
772790
let retry_session = config
773791
.prepared
774792
.get_retry_policy()
@@ -854,8 +872,8 @@ impl QueryPager {
854872
statement_info,
855873
query_is_idempotent: config.prepared.config.is_idempotent,
856874
query_consistency: consistency,
875+
load_balancing_policy,
857876
retry_session,
858-
execution_profile: config.execution_profile,
859877
#[cfg(feature = "metrics")]
860878
metrics: config.metrics,
861879
paging_state: PagingState::start(),

scylla/src/client/session.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,12 +1801,15 @@ impl Session {
18011801
.as_ref()
18021802
.map(|hl| (&**hl, hl.log_request_start()));
18031803

1804-
let load_balancer = &execution_profile.load_balancing_policy;
1804+
let load_balancer = statement_config
1805+
.load_balancing_policy
1806+
.as_deref()
1807+
.unwrap_or(execution_profile.load_balancing_policy.as_ref());
18051808

18061809
let runner = async {
18071810
let cluster_state = self.cluster.get_state();
18081811
let request_plan =
1809-
load_balancing::Plan::new(load_balancer.as_ref(), &statement_info, &cluster_state);
1812+
load_balancing::Plan::new(load_balancer, &statement_info, &cluster_state);
18101813

18111814
// If a speculative execution policy is used to run request, request_plan has to be shared
18121815
// between different async functions. This struct helps to wrap request_plan in mutex so it
@@ -1874,6 +1877,7 @@ impl Session {
18741877
consistency_set_on_statement: statement_config.consistency,
18751878
retry_session: retry_policy.new_session(),
18761879
history_data,
1880+
load_balancing_policy: load_balancer,
18771881
query_info: &statement_info,
18781882
request_span,
18791883
},
@@ -1910,6 +1914,7 @@ impl Session {
19101914
consistency_set_on_statement: statement_config.consistency,
19111915
retry_session: retry_policy.new_session(),
19121916
history_data,
1917+
load_balancing_policy: load_balancer,
19131918
query_info: &statement_info,
19141919
request_span,
19151920
},
@@ -2007,7 +2012,7 @@ impl Session {
20072012
#[cfg(feature = "metrics")]
20082013
let _ = self.metrics.log_query_latency(elapsed.as_millis() as u64);
20092014
context.log_attempt_success(&attempt_id);
2010-
execution_profile.load_balancing_policy.on_request_success(
2015+
context.load_balancing_policy.on_request_success(
20112016
context.query_info,
20122017
elapsed,
20132018
node,
@@ -2022,7 +2027,7 @@ impl Session {
20222027
);
20232028
#[cfg(feature = "metrics")]
20242029
self.metrics.inc_failed_nonpaged_queries();
2025-
execution_profile.load_balancing_policy.on_request_failure(
2030+
context.load_balancing_policy.on_request_failure(
20262031
context.query_info,
20272032
elapsed,
20282033
node,
@@ -2135,6 +2140,7 @@ struct ExecuteRequestContext<'a> {
21352140
consistency_set_on_statement: Option<Consistency>,
21362141
retry_session: Box<dyn RetrySession>,
21372142
history_data: Option<HistoryData<'a>>,
2143+
load_balancing_policy: &'a dyn load_balancing::LoadBalancingPolicy,
21382144
query_info: &'a load_balancing::RoutingInfo<'a>,
21392145
request_span: &'a RequestSpan,
21402146
}

scylla/src/policies/load_balancing/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ use std::time::Duration;
1313

1414
mod default;
1515
mod plan;
16+
mod single_target;
1617
pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder};
1718
pub use plan::Plan;
19+
pub use single_target::{NodeIdentifier, SingleTargetLoadBalancingPolicy};
1820

1921
/// Represents info about statement that can be used by load balancing policies.
2022
#[derive(Default, Clone, Debug)]

0 commit comments

Comments
 (0)