From 738b1e1c3095ad2888109754a3174a3f71a25639 Mon Sep 17 00:00:00 2001
From: Thomas BESSOU
Date: Sat, 3 Jun 2023 20:06:50 +0200
Subject: [PATCH 01/14] Add Session::first_shard_for_statement
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Resolves #468

This is a follow-up on #508 and #658:
- To minimize CPU usage related to network operations when inserting a
  very large number of rows, it is relevant to batch.
- To batch in the most efficient manner, these batches have to be
  shard-aware. Since #508, `batch` will pick the shard of the first
  statement to send the query to. However, it is left to the user to
  constitute the batches in such a way that the target shard is the
  same for all the elements of the batch.
- This was made *possible* by #658, but it was still very
  boilerplate-ish. I was waiting for #612 to be merged (amazing work
  btw! 😃) to implement a more direct and factored API (as that would
  use it).
- This new `Session::first_shard_for_statement(self,
  &PreparedStatement, &SerializedValues) -> Option<(Node,
  Option<Shard>)>` makes shard-aware batching easy for users, by
  providing access to the first node and shard of the query plan.
---
 scylla/src/transport/session.rs | 75 ++++++++++++++++++++++++++-------
 1 file changed, 59 insertions(+), 16 deletions(-)

diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs
index 783124412e..73409f9341 100644
--- a/scylla/src/transport/session.rs
+++ b/scylla/src/transport/session.rs
@@ -8,6 +8,7 @@ use crate::frame::types::LegacyConsistency;
 use crate::history;
 use crate::history::HistoryListener;
 use crate::retry_policy::RetryPolicy;
+use crate::routing;
 use arc_swap::ArcSwapOption;
 use async_trait::async_trait;
 use bytes::Bytes;
@@ -898,15 +899,7 @@ impl Session {
             .as_ref()
             .map(|pk| prepared.get_partitioner_name().hash(pk));
 
-        let statement_info = RoutingInfo {
-            consistency: prepared
-                .get_consistency()
-                .unwrap_or(self.default_execution_profile_handle.access().consistency),
-            serial_consistency: prepared.get_serial_consistency(),
-            token,
-            keyspace: prepared.get_keyspace_name(),
-            is_confirmed_lwt: prepared.is_confirmed_lwt(),
-        };
+        let statement_info = self.routing_info(prepared, token);
 
         let span =
             RequestSpan::new_prepared(partition_key.as_ref(), token, serialized_values.size());
@@ -1814,13 +1807,63 @@ impl Session {
         prepared: &PreparedStatement,
         serialized_values: &SerializedValues,
     ) -> Result<Option<Token>, QueryError> {
-        match self.calculate_partition_key(prepared, serialized_values) {
-            Ok(Some(partition_key)) => {
-                let partitioner_name = prepared.get_partitioner_name();
-                Ok(Some(partitioner_name.hash(&partition_key)))
-            }
-            Ok(None) => Ok(None),
-            Err(err) => Err(err),
+        Ok(self
+            .calculate_partition_key(prepared, serialized_values)?
+            .map(|partition_key| prepared.get_partitioner_name().hash(&partition_key)))
+    }
+
+    /// Get the first node/shard that the load balancer would target if running this query
+    ///
+    /// This may help constituting shard-aware batches
+    pub fn first_shard_for_statement(
+        &self,
+        prepared: &PreparedStatement,
+        serialized_values: &SerializedValues,
+    ) -> Result<Option<(Arc<Node>, Option<u32>)>, QueryError> {
+        let token = match self.calculate_token(prepared, serialized_values)?
{ + Some(token) => token, + None => return Ok(None), + }; + let routing_info = self.routing_info(prepared, Some(token)); + let cluster_data = self.cluster.get_data(); + let execution_profile = prepared + .config + .execution_profile_handle + .as_ref() + .unwrap_or_else(|| self.get_default_execution_profile_handle()) + .access(); + let mut query_plan = load_balancing::Plan::new( + &*execution_profile.load_balancing_policy, + &routing_info, + &cluster_data, + ); + // We can't return the full iterator here because the iterator borrows from local variables. + // In order to achieve that, two designs would be possible: + // - Construct a self-referential struct and implement iterator on it via e.g. Ouroboros + // - Take a closure as a parameter that will take the local iterator and return anything, and + // this function would return directly what the closure returns + // Most likely though, people would use this for some kind of shard-awareness optimization for batching, + // and are consequently not interested in subsequent nodes. + // Until then, let's just expose this, as it is simpler + Ok(query_plan.next().map(move |node| { + let token = node.sharder().map(|sharder| sharder.shard_of(token)); + (node.clone(), token) + })) + } + + fn routing_info<'p>( + &self, + prepared: &'p PreparedStatement, + token: Option, + ) -> RoutingInfo<'p> { + RoutingInfo { + consistency: prepared + .get_consistency() + .unwrap_or(self.default_execution_profile_handle.access().consistency), + serial_consistency: prepared.get_serial_consistency(), + token, + keyspace: prepared.get_keyspace_name(), + is_confirmed_lwt: prepared.is_confirmed_lwt(), } } From 7ac1759b84489f4175874c3043c8074a292a24ec Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sun, 4 Jun 2023 00:39:21 +0200 Subject: [PATCH 02/14] silence type-complexity lint on first_shard_for_statement --- scylla/src/transport/session.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 73409f9341..fb003b7538 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -1815,6 +1815,7 @@ impl Session { /// Get the first node/shard that the load balancer would target if running this query /// /// This may help constituting shard-aware batches + #[allow(clippy::type_complexity)] pub fn first_shard_for_statement( &self, prepared: &PreparedStatement, From 20c32b8bacbfef22ac8c635fb717bde792e37628 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sun, 2 Jul 2023 12:10:17 +0200 Subject: [PATCH 03/14] routing_info -> routing_info_from_prepared_statement --- scylla/src/transport/session.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index fb003b7538..2efd26f4aa 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -899,7 +899,7 @@ impl Session { .as_ref() .map(|pk| prepared.get_partitioner_name().hash(pk)); - let statement_info = self.routing_info(prepared, token); + let statement_info = self.routing_info_from_prepared_statement(prepared, token); let span = RequestSpan::new_prepared(partition_key.as_ref(), token, serialized_values.size()); @@ -1825,7 +1825,7 @@ impl Session { Some(token) => token, None => return Ok(None), }; - let routing_info = self.routing_info(prepared, Some(token)); + let routing_info = self.routing_info_from_prepared_statement(prepared, Some(token)); let cluster_data = self.cluster.get_data(); let execution_profile = prepared 
.config @@ -1852,7 +1852,7 @@ impl Session { })) } - fn routing_info<'p>( + fn routing_info_from_prepared_statement<'p>( &self, prepared: &'p PreparedStatement, token: Option, From 146655e6bf4b73958b1c7a3f9f0c5504a9d3ad49 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sun, 2 Jul 2023 14:53:29 +0200 Subject: [PATCH 04/14] add enforce_target_node on batch --- scylla/src/statement/batch.rs | 79 ++++++++++++++++++- scylla/src/transport/execution_profile.rs | 4 + .../transport/load_balancing/enforce_node.rs | 42 ++++++++++ scylla/src/transport/load_balancing/mod.rs | 6 +- scylla/src/transport/session.rs | 6 +- 5 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 scylla/src/transport/load_balancing/enforce_node.rs diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index ab120aea62..01fd171477 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use crate::history::HistoryListener; use crate::retry_policy::RetryPolicy; use crate::statement::{prepared_statement::PreparedStatement, query::Query}; -use crate::transport::execution_profile::ExecutionProfileHandle; +use crate::transport::{execution_profile::ExecutionProfileHandle, Node}; +use crate::Session; use super::StatementConfig; pub use super::{Consistency, SerialConsistency}; @@ -144,6 +145,82 @@ impl Batch { pub fn get_execution_profile_handle(&self) -> Option<&ExecutionProfileHandle> { self.config.execution_profile_handle.as_ref() } + + /// Associates the batch with a new execution profile that will have a load balancing policy + /// that will enforce the use of the provided [`Node`] to the extent possible. + /// + /// This should typically be used in conjunction with [`Session::shard_for_statement`], where + /// you would constitute a batch by assigning to the same batch all the statements that would be executed in + /// the same shard. 
+    ///
+    /// Since it is not guaranteed that subsequent calls to the load balancer would re-assign the statement
+    /// to the same node, you should use this method to enforce the use of the original node that was envisioned by
+    /// `shard_for_statement` for the batch:
+    ///
+    /// ```rust
+    /// # use scylla::Session;
+    /// # use std::error::Error;
+    /// # async fn check_only_compiles(session: &Session) -> Result<(), Box<dyn Error>> {
+    /// use scylla::{
+    ///     batch::Batch,
+    ///     frame::value::{SerializedValues, ValueList},
+    /// };
+    ///
+    /// let prepared_statement = session
+    ///     .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
+    ///     .await?;
+    ///
+    /// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned();
+    /// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?;
+    ///
+    /// // Send that to a task that will handle statements targeted to the same shard
+    ///
+    /// // On that task:
+    /// // Constitute a batch with all the statements that would be executed in the same shard
+    ///
+    /// let mut batch: Batch = Default::default();
+    /// if let Some((node, _shard_idx)) = shard {
+    ///     batch.enforce_target_node(&node, &session);
+    /// }
+    /// let mut batch_values = Vec::new();
+    ///
+    /// // As the task handling statements targeted to this shard receives them,
+    /// // it appends them to the batch
+    /// batch.append_statement(prepared_statement);
+    /// batch_values.push(serialized_values);
+    ///
+    /// // Run the batch
+    /// session.batch(&batch, batch_values).await?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    ///
+    ///
+    /// If the target node is not available anymore at the time of executing the statement, it will fallback to the
+    /// original load balancing policy:
+    /// - Either that currently set on the [`Batch`], if any
+    /// - Or that of the [`Session`] if there isn't one on the `Batch`
+    pub fn enforce_target_node(
+        &mut self,
+        node: &Arc<Node>,
+        base_execution_profile_from_session: &Session,
+    ) {
+        let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| {
+            base_execution_profile_from_session.get_default_execution_profile_handle()
+        });
+        self.set_execution_profile_handle(Some(
+            execution_profile_handle
+                .pointee_to_builder()
+                .load_balancing_policy(Arc::new(
+                    crate::load_balancing::EnforceTargetNodePolicy::new(
+                        node,
+                        execution_profile_handle.load_balancing_policy(),
+                    ),
+                ))
+                .build()
+                .into_handle(),
+        ))
+    }
 }
 
 impl Default for Batch {
diff --git a/scylla/src/transport/execution_profile.rs b/scylla/src/transport/execution_profile.rs
index 245beffab9..7f92fb3b18 100644
--- a/scylla/src/transport/execution_profile.rs
+++ b/scylla/src/transport/execution_profile.rs
@@ -485,4 +485,8 @@ impl ExecutionProfileHandle {
     pub fn map_to_another_profile(&mut self, profile: ExecutionProfile) {
         self.0 .0.store(profile.0)
     }
+
+    pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
+        self.0 .0.load().load_balancing_policy.clone()
+    }
 }
diff --git a/scylla/src/transport/load_balancing/enforce_node.rs b/scylla/src/transport/load_balancing/enforce_node.rs
new file mode 100644
index 0000000000..12c17bbcbf
--- /dev/null
+++ b/scylla/src/transport/load_balancing/enforce_node.rs
@@ -0,0 +1,42 @@
+use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo};
+use crate::transport::{cluster::ClusterData, Node};
+use std::sync::Arc;
+
+#[derive(Debug)]
+pub struct EnforceTargetNodePolicy {
+    target_node: uuid::Uuid,
+    fallback: Arc<dyn LoadBalancingPolicy>,
+}
+
+impl EnforceTargetNodePolicy {
+    pub fn new(target_node: &Arc<Node>, fallback: Arc<dyn LoadBalancingPolicy>) -> Self {
+        Self {
target_node: target_node.host_id, + fallback, + } + } +} +impl LoadBalancingPolicy for EnforceTargetNodePolicy { + fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option> { + cluster + .known_peers + .get(&self.target_node) + .or_else(|| self.fallback.pick(query, cluster)) + } + + fn fallback<'a>( + &'a self, + query: &'a RoutingInfo, + cluster: &'a ClusterData, + ) -> FallbackPlan<'a> { + self.fallback.fallback(query, cluster) + } + + fn name(&self) -> String { + format!( + "Enforce target node Load balancing policy - Node: {} - fallback: {}", + self.target_node, + self.fallback.name() + ) + } +} diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs index d4095743c3..64760685e7 100644 --- a/scylla/src/transport/load_balancing/mod.rs +++ b/scylla/src/transport/load_balancing/mod.rs @@ -9,9 +9,13 @@ use scylla_cql::{errors::QueryError, frame::types}; use std::time::Duration; mod default; +mod enforce_node; mod plan; -pub use default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder}; pub use plan::Plan; +pub use { + default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder}, + enforce_node::EnforceTargetNodePolicy, +}; /// Represents info about statement that can be used by load balancing policies. #[derive(Default, Clone, Debug)] diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 2efd26f4aa..53ae000ed3 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -1812,11 +1812,11 @@ impl Session { .map(|partition_key| prepared.get_partitioner_name().hash(&partition_key))) } - /// Get the first node/shard that the load balancer would target if running this query + /// Get a node/shard that the load balancer would potentially target if running this query /// - /// This may help constituting shard-aware batches + /// This may help constituting shard-aware batches (see [`Batch::enforce_target_node`]) #[allow(clippy::type_complexity)] - pub fn first_shard_for_statement( + pub fn shard_for_statement( &self, prepared: &PreparedStatement, serialized_values: &SerializedValues, From 6406e7035eed22d1927ead73d269e3a26ac42a14 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sun, 2 Jul 2023 15:27:42 +0200 Subject: [PATCH 05/14] Filter that node is alive --- scylla/src/transport/load_balancing/default.rs | 2 +- scylla/src/transport/load_balancing/enforce_node.rs | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 8cac92d6e5..be0937528c 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -594,7 +594,7 @@ impl DefaultPolicy { vec.into_iter() } - fn is_alive(node: &NodeRef<'_>) -> bool { + pub(crate) fn is_alive(node: &NodeRef<'_>) -> bool { // For now, we leave this as stub, until we have time to improve node events. 
// node.is_enabled() && !node.is_down() node.is_enabled() diff --git a/scylla/src/transport/load_balancing/enforce_node.rs b/scylla/src/transport/load_balancing/enforce_node.rs index 12c17bbcbf..f3809603c9 100644 --- a/scylla/src/transport/load_balancing/enforce_node.rs +++ b/scylla/src/transport/load_balancing/enforce_node.rs @@ -1,7 +1,11 @@ -use super::{FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; +use super::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; use crate::transport::{cluster::ClusterData, Node}; use std::sync::Arc; +/// This policy will always return the same node, unless it is not available anymore, in which case it will +/// fallback to the provided policy. +/// +/// This is meant to be used for shard-aware batching. #[derive(Debug)] pub struct EnforceTargetNodePolicy { target_node: uuid::Uuid, @@ -21,6 +25,7 @@ impl LoadBalancingPolicy for EnforceTargetNodePolicy { cluster .known_peers .get(&self.target_node) + .filter(DefaultPolicy::is_alive) .or_else(|| self.fallback.pick(query, cluster)) } From 84ee6aa34e7896e6e019527635817a429ac5ea75 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Mon, 3 Jul 2023 11:47:44 +0200 Subject: [PATCH 06/14] less qualifying more use --- scylla/src/statement/batch.rs | 11 +++++------ scylla/src/transport/load_balancing/enforce_node.rs | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index 01fd171477..7665a8198f 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use crate::history::HistoryListener; +use crate::load_balancing; use crate::retry_policy::RetryPolicy; use crate::statement::{prepared_statement::PreparedStatement, query::Query}; use crate::transport::{execution_profile::ExecutionProfileHandle, Node}; @@ -211,12 +212,10 @@ impl Batch { self.set_execution_profile_handle(Some( execution_profile_handle .pointee_to_builder() - .load_balancing_policy(Arc::new( - crate::load_balancing::EnforceTargetNodePolicy::new( - node, - execution_profile_handle.load_balancing_policy(), - ), - )) + .load_balancing_policy(Arc::new(load_balancing::EnforceTargetNodePolicy::new( + node, + execution_profile_handle.load_balancing_policy(), + ))) .build() .into_handle(), )) diff --git a/scylla/src/transport/load_balancing/enforce_node.rs b/scylla/src/transport/load_balancing/enforce_node.rs index f3809603c9..3b56508498 100644 --- a/scylla/src/transport/load_balancing/enforce_node.rs +++ b/scylla/src/transport/load_balancing/enforce_node.rs @@ -1,6 +1,7 @@ use super::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; use crate::transport::{cluster::ClusterData, Node}; use std::sync::Arc; +use uuid::Uuid; /// This policy will always return the same node, unless it is not available anymore, in which case it will /// fallback to the provided policy. @@ -8,7 +9,7 @@ use std::sync::Arc; /// This is meant to be used for shard-aware batching. 
 #[derive(Debug)]
 pub struct EnforceTargetNodePolicy {
-    target_node: uuid::Uuid,
+    target_node: Uuid,
     fallback: Arc<dyn LoadBalancingPolicy>,
 }
 
From a651fedc4f450b01f214d2e02c68fded661e964f Mon Sep 17 00:00:00 2001
From: Thomas BESSOU
Date: Mon, 3 Jul 2023 15:08:18 +0200
Subject: [PATCH 07/14] fix doc being missing for ExecutionProfileHandle::load_balancing_policy

---
 scylla/src/transport/execution_profile.rs | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/scylla/src/transport/execution_profile.rs b/scylla/src/transport/execution_profile.rs
index 7f92fb3b18..97da37156d 100644
--- a/scylla/src/transport/execution_profile.rs
+++ b/scylla/src/transport/execution_profile.rs
@@ -486,7 +486,18 @@ impl ExecutionProfileHandle {
         self.0 .0.store(profile.0)
     }
 
+    /// Get the load balancing policy associated with this execution profile.
+    ///
+    /// This may be useful if one wants to construct a new load balancing policy
+    /// that is based on the one associated with this execution profile.
     pub fn load_balancing_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
+        // Exposed as a building block of `Batch::enforce_target_node` in case a user
+        // wants more control than what that method does.
+        // Since the fact that the load balancing policy is accessible from the
+        // ExecutionProfileHandle is already public API through the fact it's documented
+        // that it would be preserved by pointee_to_builder, having this as public API
+        // doesn't prevent any more non-breaking evolution than would already be
+        // blocked anyway
         self.0 .0.load().load_balancing_policy.clone()
     }
 }

From 59ced8f72420c0b2dae61f18ff83b969f7cdf2d7 Mon Sep 17 00:00:00 2001
From: Thomas BESSOU
Date: Sun, 6 Aug 2023 13:10:51 +0200
Subject: [PATCH 08/14] wip example test

---
 scylla/Cargo.toml                             |   2 +
 scylla/tests/integration/main.rs              |   1 +
 .../tests/integration/shard_aware_batching.rs | 193 ++++++++++++++++++
 3 files changed, 196 insertions(+)
 create mode 100644 scylla/tests/integration/shard_aware_batching.rs

diff --git a/scylla/Cargo.toml b/scylla/Cargo.toml
index f83bdbd240..d6ed424d51 100644
--- a/scylla/Cargo.toml
+++ b/scylla/Cargo.toml
@@ -60,6 +60,8 @@ criterion = "0.5"
 tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
 assert_matches = "1.5.0"
 rand_chacha = "0.3.1"
+futures-batch = "0.6.1"
+tokio-stream = "0.1.14"
 
 [[bench]]
 name = "benchmark"
diff --git a/scylla/tests/integration/main.rs b/scylla/tests/integration/main.rs
index c2cb8fc605..c4e4742cd3 100644
--- a/scylla/tests/integration/main.rs
+++ b/scylla/tests/integration/main.rs
@@ -4,4 +4,5 @@ mod hygiene;
 mod lwt_optimisation;
 mod new_session;
 mod retries;
+mod shard_aware_batching;
 pub(crate) mod utils;
diff --git a/scylla/tests/integration/shard_aware_batching.rs b/scylla/tests/integration/shard_aware_batching.rs
new file mode 100644
index 0000000000..ec0100fd66
--- /dev/null
+++ b/scylla/tests/integration/shard_aware_batching.rs
@@ -0,0 +1,193 @@
+use crate::utils::test_with_3_node_cluster;
+use futures::prelude::*;
+use futures_batch::ChunksTimeoutStreamExt;
+use scylla::frame::value::ValueList;
+use scylla::retry_policy::FallthroughRetryPolicy;
+use scylla::test_utils::unique_keyspace_name;
+use scylla::transport::session::Session;
+use scylla::{ExecutionProfile, SessionBuilder};
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+
+use scylla_proxy::{
+    Condition, ProxyError, Reaction, RequestFrame, RequestOpcode, RequestReaction, RequestRule,
+    RunningProxy, ShardAwareness, WorkerError,
+};
+
+#[tokio::test]
+#[ntest::timeout(20000)]
+#[cfg(not(scylla_cloud_tests))]
+async fn shard_aware_batching_pattern_routes_to_proper_shard() {
+    let res = test_with_3_node_cluster(ShardAwareness::QueryNode, run_test).await;
+
+    match res {
+        Ok(()) => (),
+        Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
+        Err(err) => panic!("{}", err),
+    }
+}
+
+async fn run_test(
+    proxy_uris: [String; 3],
+    translation_map: HashMap<std::net::SocketAddr, std::net::SocketAddr>,
+    mut running_proxy: RunningProxy,
+) -> RunningProxy {
+    // This is just to increase the likelihood that only intended prepared statements (which contain this mark) are captured by the proxy.
+    const MAGIC_MARK: i32 = 123;
+
+    // We set up the proxy so that it passes us information about which node was queried (via prepared_rx).
+
+    let prepared_rule = |tx| {
+        RequestRule(
+            Condition::and(
+                Condition::RequestOpcode(RequestOpcode::Batch),
+                Condition::BodyContainsCaseSensitive(Box::new(MAGIC_MARK.to_be_bytes())),
+            ),
+            RequestReaction::noop().with_feedback_when_performed(tx),
+        )
+    };
+
+    let mut prepared_rxs = [0, 1, 2].map(|i| {
+        let (prepared_tx, prepared_rx) = mpsc::unbounded_channel();
+        running_proxy.running_nodes[i].change_request_rules(Some(vec![prepared_rule(prepared_tx)]));
+        prepared_rx
+    });
+
+    let handle = ExecutionProfile::builder()
+        .retry_policy(Box::new(FallthroughRetryPolicy))
+        .build()
+        .into_handle();
+
+    // DB preparation phase
+    let session: Arc<Session> = Arc::new(
+        SessionBuilder::new()
+            .known_node(proxy_uris[0].as_str())
+            .default_execution_profile_handle(handle)
+            .address_translator(Arc::new(translation_map))
+            .build()
+            .await
+            .unwrap(),
+    );
+
+    // Create schema
+    let ks = unique_keyspace_name();
+    session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 3}}", ks), &[]).await.unwrap();
+    session.use_keyspace(ks, false).await.unwrap();
+
+    session
+        .query("CREATE TABLE t (a int primary key, b int)", &[])
+        .await
+        .unwrap();
+
+    // We will check which nodes were queried by the shard-aware batches.
+    let prepared_statement = session
+        .prepare("INSERT INTO t (a, b) VALUES (?, ?)")
+        .await
+        .unwrap();
+
+    assert!(prepared_statement.is_token_aware());
+
+    // Build the shard-aware batching system
+
+    #[derive(Clone, Copy, PartialEq, Eq, Hash)]
+    struct DestinationShard {
+        node_id: uuid::Uuid,
+        shard_id_on_node: Option<u32>,
+    }
+    let mut channels_for_shards: HashMap<
+        DestinationShard,
+        tokio::sync::mpsc::Sender<scylla::frame::value::SerializedValues>,
+    > = HashMap::new();
+    let mut batching_tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new(); // To make sure nothing panicked
+    for i in 0..150 {
+        let values = (i, MAGIC_MARK);
+
+        let serialized_values = values
+            .serialized()
+            .expect("Failed to serialize values")
+            .into_owned();
+
+        let (node, shard_id_on_node) = session
+            .shard_for_statement(&prepared_statement, &serialized_values)
+            .expect("Error when getting shard for statement")
+            .expect("Query is not shard-aware");
+        let destination_shard = DestinationShard {
+            node_id: node.host_id,
+            shard_id_on_node,
+        };
+
+        // Typically if rows may come from different places, the `channels_for_shards` `HashMap`
+        // would be behind a mutex, but for this example we keep it simple.
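+        //
+        // The pattern here: one channel and one batching task per destination (node, shard) pair.
+        // Each row is routed to the channel matching its `shard_for_statement` result, so any
+        // given batch only ever contains statements that target a single shard.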
+ // Create the task that constitutes and sends the batches for this shard if it doesn't already exist + + let sender = channels_for_shards + .entry(destination_shard) + .or_insert_with(|| { + let (sender, receiver) = tokio::sync::mpsc::channel(10000); + let prepared_statement = prepared_statement.clone(); + let session = session.clone(); + + let mut scylla_batch = + scylla::batch::Batch::new(scylla::batch::BatchType::Unlogged); + scylla_batch.enforce_target_node(&node, &session); + + batching_tasks.push(tokio::spawn(async move { + let mut batches = ReceiverStream::new(receiver) + .chunks_timeout(10, Duration::from_millis(100)); + + while let Some(batch) = batches.next().await { + // Obviously if the actual prepared statement depends on each element of the batch + // this requires adjustment + scylla_batch.statements.resize_with(batch.len(), || { + scylla::batch::BatchStatement::PreparedStatement( + prepared_statement.clone(), + ) + }); + + session + .batch(&scylla_batch, &batch) + .await + .expect("Query to send batch failed"); + } + })); + sender + }); + + sender + .send(serialized_values) + .await + .expect("Failed to send serialized values to dedicated channel"); + } + + // Let's drop the senders, which will ensure that all batches are sent immediately, + // then wait for all the tasks to finish, and ensure that there were no errors + // In a production setting these dynamically instantiated tasks may be monitored more easily + // by using e.g. `tokio_tasks_shutdown` + std::mem::drop(channels_for_shards); + for task in batching_tasks { + task.await.unwrap(); + } + + // finally check that batching was indeed shard-aware. + + // TODO + + // wip: make sure we did capture the queries to each node + fn clear_rxs(rxs: &mut [mpsc::UnboundedReceiver; 3]) { + for rx in rxs.iter_mut() { + while rx.try_recv().is_ok() {} + } + } + async fn assert_all_replicas_queried(rxs: &mut [mpsc::UnboundedReceiver; 3]) { + for rx in rxs.iter_mut() { + rx.recv().await.unwrap(); + } + clear_rxs(rxs); + } + assert_all_replicas_queried(&mut prepared_rxs).await; + + running_proxy +} From f81d87f6dc8f8212409a0763b3cf71d0f9570f83 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 25 May 2024 19:58:49 +0200 Subject: [PATCH 09/14] adapt after master merge --- scylla-cql/src/types/serialize/row.rs | 38 +++- scylla/src/statement/batch.rs | 42 +++-- scylla/src/statement/prepared_statement.rs | 2 +- .../src/transport/load_balancing/default.rs | 2 +- .../transport/load_balancing/enforce_node.rs | 30 +++- scylla/src/transport/load_balancing/mod.rs | 2 +- scylla/src/transport/session.rs | 165 ++++++++++-------- .../tests/integration/shard_aware_batching.rs | 21 +-- 8 files changed, 186 insertions(+), 116 deletions(-) diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index f980064388..1cfafefd68 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -80,6 +80,17 @@ pub trait SerializeRow { /// the bind marker types and names so that the values can be properly /// type checked and serialized. fn is_empty(&self) -> bool; + + /// Specialization that allows the driver to not re-serialize the row if it's already + /// a `SerializedValues` + /// + /// Note that if using this, it's the user's responsibility to ensure that this + /// `SerializedValues` has been generated with the same prepared statement as the query + /// is going to be made with. + #[inline] + fn already_serialized(&self) -> Option<&SerializedValues> { + None + } } macro_rules! 
fallback_impl_contents { @@ -255,12 +266,35 @@ impl SerializeRow for &T { ctx: &RowSerializationContext<'_>, writer: &mut RowWriter, ) -> Result<(), SerializationError> { - ::serialize(self, ctx, writer) + ::serialize(*self, ctx, writer) } #[inline] fn is_empty(&self) -> bool { - ::is_empty(self) + ::is_empty(*self) + } + + #[inline] + fn already_serialized(&self) -> Option<&SerializedValues> { + ::already_serialized(*self) + } +} + +impl SerializeRow for SerializedValues { + fn serialize( + &self, + _ctx: &RowSerializationContext<'_>, + writer: &mut RowWriter, + ) -> Result<(), SerializationError> { + Ok(writer.append_serialize_row(self)) + } + + fn is_empty(&self) -> bool { + self.is_empty() + } + + fn already_serialized(&self) -> Option<&SerializedValues> { + Some(self) } } diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs index 55b1f80669..762f709f14 100644 --- a/scylla/src/statement/batch.rs +++ b/scylla/src/statement/batch.rs @@ -4,8 +4,10 @@ use std::sync::Arc; use crate::history::HistoryListener; use crate::load_balancing; use crate::retry_policy::RetryPolicy; +use crate::routing::Shard; use crate::statement::{prepared_statement::PreparedStatement, query::Query}; -use crate::transport::{execution_profile::ExecutionProfileHandle, Node}; +use crate::transport::execution_profile::ExecutionProfileHandle; +use crate::transport::NodeRef; use crate::Session; use super::StatementConfig; @@ -145,31 +147,31 @@ impl Batch { self.config.execution_profile_handle.as_ref() } - /// Associates the batch with a new execution profile that will have a load balancing policy - /// that will enforce the use of the provided [`Node`] to the extent possible. + /// Associates the batch with a new execution profile that will have a load + /// balancing policy that will enforce the use of the provided [`Node`] + /// to the extent possible. /// - /// This should typically be used in conjunction with [`Session::shard_for_statement`], where - /// you would constitute a batch by assigning to the same batch all the statements that would be executed in - /// the same shard. + /// This should typically be used in conjunction with + /// [`Session::shard_for_statement`], where you would constitute a batch + /// by assigning to the same batch all the statements that would be executed + /// in the same shard. 
/// - /// Since it is not guaranteed that subsequent calls to the load balancer would re-assign the statement - /// to the same node, you should use this method to enforce the use of the original node that was envisioned by + /// Since it is not guaranteed that subsequent calls to the load balancer + /// would re-assign the statement to the same node, you should use this + /// method to enforce the use of the original node that was envisioned by /// `shard_for_statement` for the batch: /// /// ```rust /// # use scylla::Session; /// # use std::error::Error; /// # async fn check_only_compiles(session: &Session) -> Result<(), Box> { - /// use scylla::{ - /// batch::Batch, - /// frame::value::{SerializedValues, ValueList}, - /// }; + /// use scylla::{batch::Batch, serialize::row::SerializedValues}; /// /// let prepared_statement = session /// .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)") /// .await?; /// - /// let serialized_values: SerializedValues = (1, 2).serialized()?.into_owned(); + /// let serialized_values: SerializedValues = prepared_statement.serialize_values(&(1, 2))?; /// let shard = session.shard_for_statement(&prepared_statement, &serialized_values)?; /// /// // Send that to a task that will handle statements targeted to the same shard @@ -178,8 +180,8 @@ impl Batch { /// // Constitute a batch with all the statements that would be executed in the same shard /// /// let mut batch: Batch = Default::default(); - /// if let Some((node, _shard_idx)) = shard { - /// batch.enforce_target_node(&node, &session); + /// if let Some((node, shard_idx)) = shard { + /// batch.enforce_target_node(&node, shard_idx, &session); /// } /// let mut batch_values = Vec::new(); /// @@ -195,13 +197,14 @@ impl Batch { /// ``` /// /// - /// If the target node is not available anymore at the time of executing the statement, it will fallback to the - /// original load balancing policy: + /// If the target node is not available anymore at the time of executing the + /// statement, it will fallback to the original load balancing policy: /// - Either that currently set on the [`Batch`], if any /// - Or that of the [`Session`] if there isn't one on the `Batch` pub fn enforce_target_node( &mut self, - node: &Arc, + node: NodeRef<'_>, + shard: Shard, base_execution_profile_from_session: &Session, ) { let execution_profile_handle = self.get_execution_profile_handle().unwrap_or_else(|| { @@ -210,8 +213,9 @@ impl Batch { self.set_execution_profile_handle(Some( execution_profile_handle .pointee_to_builder() - .load_balancing_policy(Arc::new(load_balancing::EnforceTargetNodePolicy::new( + .load_balancing_policy(Arc::new(load_balancing::EnforceTargetShardPolicy::new( node, + shard, execution_profile_handle.load_balancing_policy(), ))) .build() diff --git a/scylla/src/statement/prepared_statement.rs b/scylla/src/statement/prepared_statement.rs index 6287a1492e..bbbdf1e39b 100644 --- a/scylla/src/statement/prepared_statement.rs +++ b/scylla/src/statement/prepared_statement.rs @@ -459,7 +459,7 @@ impl PreparedStatement { self.config.execution_profile_handle.as_ref() } - pub(crate) fn serialize_values( + pub fn serialize_values( &self, values: &impl SerializeRow, ) -> Result { diff --git a/scylla/src/transport/load_balancing/default.rs b/scylla/src/transport/load_balancing/default.rs index 4280c855fa..062de20926 100644 --- a/scylla/src/transport/load_balancing/default.rs +++ b/scylla/src/transport/load_balancing/default.rs @@ -703,7 +703,7 @@ impl DefaultPolicy { vec.into_iter() } - fn is_alive(node: NodeRef, _shard: 
Option) -> bool { + pub(crate) fn is_alive(node: NodeRef, _shard: Option) -> bool { // For now, we leave this as stub, until we have time to improve node events. // node.is_enabled() && !node.is_down() node.is_enabled() diff --git a/scylla/src/transport/load_balancing/enforce_node.rs b/scylla/src/transport/load_balancing/enforce_node.rs index 3b56508498..e2e7dde11c 100644 --- a/scylla/src/transport/load_balancing/enforce_node.rs +++ b/scylla/src/transport/load_balancing/enforce_node.rs @@ -1,5 +1,8 @@ use super::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, NodeRef, RoutingInfo}; -use crate::transport::{cluster::ClusterData, Node}; +use crate::{ + routing::Shard, + transport::{cluster::ClusterData, Node}, +}; use std::sync::Arc; use uuid::Uuid; @@ -8,25 +11,36 @@ use uuid::Uuid; /// /// This is meant to be used for shard-aware batching. #[derive(Debug)] -pub struct EnforceTargetNodePolicy { +pub struct EnforceTargetShardPolicy { target_node: Uuid, + shard: Shard, fallback: Arc, } -impl EnforceTargetNodePolicy { - pub fn new(target_node: &Arc, fallback: Arc) -> Self { +impl EnforceTargetShardPolicy { + pub fn new( + target_node: &Arc, + shard: Shard, + fallback: Arc, + ) -> Self { Self { target_node: target_node.host_id, + shard, fallback, } } } -impl LoadBalancingPolicy for EnforceTargetNodePolicy { - fn pick<'a>(&'a self, query: &'a RoutingInfo, cluster: &'a ClusterData) -> Option> { +impl LoadBalancingPolicy for EnforceTargetShardPolicy { + fn pick<'a>( + &'a self, + query: &'a RoutingInfo, + cluster: &'a ClusterData, + ) -> Option<(NodeRef<'a>, Option)> { cluster .known_peers .get(&self.target_node) - .filter(DefaultPolicy::is_alive) + .map(|node| (node, Some(self.shard))) + .filter(|&(node, shard)| DefaultPolicy::is_alive(node, shard)) .or_else(|| self.fallback.pick(query, cluster)) } @@ -40,7 +54,7 @@ impl LoadBalancingPolicy for EnforceTargetNodePolicy { fn name(&self) -> String { format!( - "Enforce target node Load balancing policy - Node: {} - fallback: {}", + "Enforce target shard Load balancing policy - Node: {} - fallback: {}", self.target_node, self.fallback.name() ) diff --git a/scylla/src/transport/load_balancing/mod.rs b/scylla/src/transport/load_balancing/mod.rs index 1f2cf70c2c..7c5e6769fd 100644 --- a/scylla/src/transport/load_balancing/mod.rs +++ b/scylla/src/transport/load_balancing/mod.rs @@ -17,7 +17,7 @@ mod plan; pub use plan::Plan; pub use { default::{DefaultPolicy, DefaultPolicyBuilder, LatencyAwarenessBuilder}, - enforce_node::EnforceTargetNodePolicy, + enforce_node::EnforceTargetShardPolicy, }; /// Represents info about statement that can be used by load balancing policies. 
diff --git a/scylla/src/transport/session.rs b/scylla/src/transport/session.rs index 752b1fecd3..2385ea3c3e 100644 --- a/scylla/src/transport/session.rs +++ b/scylla/src/transport/session.rs @@ -19,8 +19,8 @@ pub use scylla_cql::errors::TranslationError; use scylla_cql::frame::response::result::{deser_cql_value, ColumnSpec, Rows}; use scylla_cql::frame::response::NonErrorResponse; use scylla_cql::types::serialize::batch::BatchValues; -use scylla_cql::types::serialize::row::SerializeRow; -use std::borrow::Borrow; +use scylla_cql::types::serialize::row::{SerializeRow, SerializedValues}; +use std::borrow::{Borrow, Cow}; use std::collections::HashMap; use std::fmt::Display; use std::future::Future; @@ -985,78 +985,92 @@ impl Session { values: impl SerializeRow, paging_state: Option, ) -> Result { - let serialized_values = prepared.serialize_values(&values)?; - let values_ref = &serialized_values; - let paging_state_ref = &paging_state; - - let (partition_key, token) = prepared - .extract_partition_key_and_calculate_token(prepared.get_partitioner_name(), values_ref)? - .unzip(); - - let (execution_profile, statement_info) = - self.execution_profile_and_routing_info_from_prepared_statement(prepared, token); + // Inlining + const propagation make this optimization zero-cost in unrelated cases: + let serialized_values = match values.already_serialized() { + None => Cow::Owned(prepared.serialize_values(&values)?), + Some(serialized) => Cow::Borrowed(serialized), + }; - let span = RequestSpan::new_prepared( - partition_key.as_ref().map(|pk| pk.iter()), - token, - serialized_values.buffer_size(), - ); + return non_generic_inner(self, prepared, &serialized_values, &paging_state).await; + /// Avoid monomorphizing this whole function for every SerializeRow type + async fn non_generic_inner( + self_: &Session, + prepared: &PreparedStatement, + values_ref: &SerializedValues, + paging_state_ref: &Option, + ) -> Result { + let (partition_key, token) = prepared + .extract_partition_key_and_calculate_token( + prepared.get_partitioner_name(), + values_ref, + )? 
+ .unzip(); + + let (execution_profile, statement_info) = + self_.execution_profile_and_routing_info_from_prepared_statement(prepared, token); + + let span = RequestSpan::new_prepared( + partition_key.as_ref().map(|pk| pk.iter()), + token, + values_ref.buffer_size(), + ); - if !span.span().is_disabled() { - if let (Some(table_spec), Some(token)) = (statement_info.table, token) { - let cluster_data = self.get_cluster_data(); - let replicas: smallvec::SmallVec<[_; 8]> = cluster_data - .get_token_endpoints_iter(table_spec, token) - .collect(); - span.record_replicas(&replicas) + if !span.span().is_disabled() { + if let (Some(table_spec), Some(token)) = (statement_info.table, token) { + let cluster_data = self_.get_cluster_data(); + let replicas: smallvec::SmallVec<[_; 8]> = cluster_data + .get_token_endpoints_iter(table_spec, token) + .collect(); + span.record_replicas(&replicas) + } } - } - let run_query_result: RunQueryResult = self - .run_query( - statement_info, - &prepared.config, - execution_profile, - |connection: Arc, - consistency: Consistency, - execution_profile: &ExecutionProfileInner| { - let serial_consistency = prepared - .config - .serial_consistency - .unwrap_or(execution_profile.serial_consistency); - async move { - connection - .execute_with_consistency( - prepared, - values_ref, - consistency, - serial_consistency, - paging_state_ref.clone(), - ) - .await - .and_then(QueryResponse::into_non_error_query_response) - } - }, - &span, - ) - .instrument(span.span().clone()) - .await?; + let run_query_result: RunQueryResult = self_ + .run_query( + statement_info, + &prepared.config, + execution_profile, + |connection: Arc, + consistency: Consistency, + execution_profile: &ExecutionProfileInner| { + let serial_consistency = prepared + .config + .serial_consistency + .unwrap_or(execution_profile.serial_consistency); + async move { + connection + .execute_with_consistency( + prepared, + values_ref, + consistency, + serial_consistency, + paging_state_ref.clone(), + ) + .await + .and_then(QueryResponse::into_non_error_query_response) + } + }, + &span, + ) + .instrument(span.span().clone()) + .await?; - let response = match run_query_result { - RunQueryResult::IgnoredWriteError => NonErrorQueryResponse { - response: NonErrorResponse::Result(result::Result::Void), - tracing_id: None, - warnings: Vec::new(), - }, - RunQueryResult::Completed(response) => response, - }; + let response = match run_query_result { + RunQueryResult::IgnoredWriteError => NonErrorQueryResponse { + response: NonErrorResponse::Result(result::Result::Void), + tracing_id: None, + warnings: Vec::new(), + }, + RunQueryResult::Completed(response) => response, + }; - self.handle_set_keyspace_response(&response).await?; - self.handle_auto_await_schema_agreement(&response).await?; + self_.handle_set_keyspace_response(&response).await?; + self_.handle_auto_await_schema_agreement(&response).await?; - let result = response.into_query_result()?; - span.record_result_fields(&result); - Ok(result) + let result = response.into_query_result()?; + span.record_result_fields(&result); + Ok(result) + } } /// Run a prepared query with paging\ @@ -1813,11 +1827,15 @@ impl Session { &self, prepared: &PreparedStatement, serialized_values: &SerializedValues, - ) -> Result, Option)>, QueryError> { - let token = match self.calculate_token(prepared, serialized_values)? 
{ - Some(token) => token, + ) -> Result, routing::Shard)>, QueryError> { + let token = match prepared.extract_partition_key_and_calculate_token( + prepared.get_partitioner_name(), + serialized_values, + )? { + Some((_partition_key, token)) => token, None => return Ok(None), }; + let (execution_profile, routing_info) = self.execution_profile_and_routing_info_from_prepared_statement(prepared, Some(token)); let cluster_data = self.cluster.get_data(); @@ -1835,10 +1853,9 @@ impl Session { // Most likely though, people would use this for some kind of shard-awareness optimization for batching, // and are consequently not interested in subsequent nodes. // Until then, let's just expose this, as it is simpler - Ok(query_plan.next().map(move |node| { - let token = node.sharder().map(|sharder| sharder.shard_of(token)); - (node.clone(), token) - })) + Ok(query_plan + .next() + .map(|(node, shard)| (Arc::clone(node), shard))) } fn execution_profile_and_routing_info_from_prepared_statement<'p>( diff --git a/scylla/tests/integration/shard_aware_batching.rs b/scylla/tests/integration/shard_aware_batching.rs index ec0100fd66..39924d5cc2 100644 --- a/scylla/tests/integration/shard_aware_batching.rs +++ b/scylla/tests/integration/shard_aware_batching.rs @@ -1,8 +1,8 @@ use crate::utils::test_with_3_node_cluster; use futures::prelude::*; use futures_batch::ChunksTimeoutStreamExt; -use scylla::frame::value::ValueList; use scylla::retry_policy::FallthroughRetryPolicy; +use scylla::serialize::row::SerializedValues; use scylla::test_utils::unique_keyspace_name; use scylla::transport::session::Session; use scylla::{ExecutionProfile, SessionBuilder}; @@ -95,20 +95,19 @@ async fn run_test( #[derive(Clone, Copy, PartialEq, Eq, Hash)] struct DestinationShard { node_id: uuid::Uuid, - shard_id_on_node: Option, + shard_id_on_node: u32, } let mut channels_for_shards: HashMap< DestinationShard, - tokio::sync::mpsc::Sender, + tokio::sync::mpsc::Sender, > = HashMap::new(); let mut batching_tasks: Vec> = Vec::new(); // To make sure nothing panicked for i in 0..150 { let values = (i, MAGIC_MARK); - let serialized_values = values - .serialized() - .expect("Failed to serialize values") - .into_owned(); + let serialized_values = prepared_statement + .serialize_values(&values) + .expect("Failed to serialize values"); let (node, shard_id_on_node) = session .shard_for_statement(&prepared_statement, &serialized_values) @@ -132,7 +131,7 @@ async fn run_test( let mut scylla_batch = scylla::batch::Batch::new(scylla::batch::BatchType::Unlogged); - scylla_batch.enforce_target_node(&node, &session); + scylla_batch.enforce_target_node(&node, shard_id_on_node, &session); batching_tasks.push(tokio::spawn(async move { let mut batches = ReceiverStream::new(receiver) @@ -176,12 +175,14 @@ async fn run_test( // TODO // wip: make sure we did capture the queries to each node - fn clear_rxs(rxs: &mut [mpsc::UnboundedReceiver; 3]) { + fn clear_rxs(rxs: &mut [mpsc::UnboundedReceiver<(RequestFrame, Option)>; 3]) { for rx in rxs.iter_mut() { while rx.try_recv().is_ok() {} } } - async fn assert_all_replicas_queried(rxs: &mut [mpsc::UnboundedReceiver; 3]) { + async fn assert_all_replicas_queried( + rxs: &mut [mpsc::UnboundedReceiver<(RequestFrame, Option)>; 3], + ) { for rx in rxs.iter_mut() { rx.recv().await.unwrap(); } From fe8af81d9998eb54e8ca660b68f7012c20343814 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 25 May 2024 22:15:17 +0200 Subject: [PATCH 10/14] test shard-awareness following pattern usage --- 
scylla-cql/src/types/serialize/row.rs | 3 +- .../tests/integration/shard_aware_batching.rs | 62 ++++++++++++++----- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/scylla-cql/src/types/serialize/row.rs b/scylla-cql/src/types/serialize/row.rs index 1cfafefd68..6b968060a7 100644 --- a/scylla-cql/src/types/serialize/row.rs +++ b/scylla-cql/src/types/serialize/row.rs @@ -286,7 +286,8 @@ impl SerializeRow for SerializedValues { _ctx: &RowSerializationContext<'_>, writer: &mut RowWriter, ) -> Result<(), SerializationError> { - Ok(writer.append_serialize_row(self)) + writer.append_serialize_row(self); + Ok(()) } fn is_empty(&self) -> bool { diff --git a/scylla/tests/integration/shard_aware_batching.rs b/scylla/tests/integration/shard_aware_batching.rs index 39924d5cc2..a5c4040670 100644 --- a/scylla/tests/integration/shard_aware_batching.rs +++ b/scylla/tests/integration/shard_aware_batching.rs @@ -2,6 +2,7 @@ use crate::utils::test_with_3_node_cluster; use futures::prelude::*; use futures_batch::ChunksTimeoutStreamExt; use scylla::retry_policy::FallthroughRetryPolicy; +use scylla::routing::Shard; use scylla::serialize::row::SerializedValues; use scylla::test_utils::unique_keyspace_name; use scylla::transport::session::Session; @@ -13,8 +14,8 @@ use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use scylla_proxy::{ - Condition, ProxyError, Reaction, RequestFrame, RequestOpcode, RequestReaction, RequestRule, - RunningProxy, ShardAwareness, WorkerError, + Condition, ProxyError, Reaction, RequestOpcode, RequestReaction, RequestRule, RunningProxy, + ShardAwareness, WorkerError, }; #[tokio::test] @@ -55,6 +56,8 @@ async fn run_test( running_proxy.running_nodes[i].change_request_rules(Some(vec![prepared_rule(prepared_tx)])); prepared_rx }); + let shards_for_nodes_test_check: Arc>>> = + Default::default(); let handle = ExecutionProfile::builder() .retry_policy(Box::new(FallthroughRetryPolicy)) @@ -133,6 +136,7 @@ async fn run_test( scylla::batch::Batch::new(scylla::batch::BatchType::Unlogged); scylla_batch.enforce_target_node(&node, shard_id_on_node, &session); + let shards_for_nodes_test_check_clone = Arc::clone(&shards_for_nodes_test_check); batching_tasks.push(tokio::spawn(async move { let mut batches = ReceiverStream::new(receiver) .chunks_timeout(10, Duration::from_millis(100)); @@ -150,6 +154,13 @@ async fn run_test( .batch(&scylla_batch, &batch) .await .expect("Query to send batch failed"); + + shards_for_nodes_test_check_clone + .lock() + .await + .entry(destination_shard.node_id) + .or_default() + .push(destination_shard.shard_id_on_node); } })); sender @@ -172,23 +183,40 @@ async fn run_test( // finally check that batching was indeed shard-aware. 
-
-    // TODO
-
-    // wip: make sure we did capture the queries to each node
-    fn clear_rxs(rxs: &mut [mpsc::UnboundedReceiver<(RequestFrame, Option<TargetShard>)>; 3]) {
-        for rx in rxs.iter_mut() {
-            while rx.try_recv().is_ok() {}
+    let mut expected: Vec<Vec<Shard>> = Arc::try_unwrap(shards_for_nodes_test_check)
+        .expect("All batching tasks have finished")
+        .into_inner()
+        .into_values()
+        .collect();
+
+    let mut nodes_shards_calls: Vec<Vec<Shard>> = Vec::new();
+    for rx in prepared_rxs.iter_mut() {
+        let mut shards_calls = Vec::new();
+        shards_calls.push(
+            rx.recv()
+                .await
+                .expect("Each node should have received at least one message")
+                .1
+                .expect("Calls should be shard-aware")
+                .into(),
+        );
+        loop {
+            match rx.try_recv() {
+                Ok((_, call_shard)) => {
+                    shards_calls.push(call_shard.expect("Calls should all be shard-aware").into())
+                }
+                Err(_) => break,
+            }
         }
+        nodes_shards_calls.push(shards_calls);
     }
-    async fn assert_all_replicas_queried(
-        rxs: &mut [mpsc::UnboundedReceiver<(RequestFrame, Option<TargetShard>)>; 3],
-    ) {
-        for rx in rxs.iter_mut() {
-            rx.recv().await.unwrap();
-        }
-        clear_rxs(rxs);
-    }
-    assert_all_replicas_queried(&mut prepared_rxs).await;
+
+    // We don't know which node is which, but once we ignore node identity,
+    // both sides should agree about what was sent to which shard.
+    dbg!(&expected, &nodes_shards_calls);
+    expected.sort_unstable();
+    nodes_shards_calls.sort_unstable();
+    assert_eq!(expected, nodes_shards_calls);
 
     running_proxy
 }

From 515417e50f8acfca33468c44e3c508747bf5d787 Mon Sep 17 00:00:00 2001
From: Thomas BESSOU
Date: Sat, 25 May 2024 22:20:00 +0200
Subject: [PATCH 11/14] fix doclink

---
 scylla/src/statement/batch.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scylla/src/statement/batch.rs b/scylla/src/statement/batch.rs
index 762f709f14..d668c32d65 100644
--- a/scylla/src/statement/batch.rs
+++ b/scylla/src/statement/batch.rs
@@ -148,7 +148,7 @@ impl Batch {
     }
 
     /// Associates the batch with a new execution profile that will have a load
-    /// balancing policy that will enforce the use of the provided [`Node`]
+    /// balancing policy that will enforce the use of the provided [`NodeRef`]
     /// to the extent possible.
/// /// This should typically be used in conjunction with From f062b4dee89759fca18246ade4b363e8196923af Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 25 May 2024 22:20:10 +0200 Subject: [PATCH 12/14] Fix Cargo.lock.msrv --- Cargo.lock.msrv | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/Cargo.lock.msrv b/Cargo.lock.msrv index 4ddbe9dae1..0bc434d30e 100644 --- a/Cargo.lock.msrv +++ b/Cargo.lock.msrv @@ -586,6 +586,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-batch" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f444c45a1cb86f2a7e301469fd50a82084a60dadc25d94529a8312276ecb71a" +dependencies = [ + "futures", + "futures-timer", + "pin-utils", +] + [[package]] name = "futures-channel" version = "0.3.28" @@ -642,6 +653,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.28" @@ -1490,6 +1507,7 @@ dependencies = [ "criterion", "dashmap", "futures", + "futures-batch", "hashbrown 0.14.0", "histogram", "itertools 0.11.0", @@ -1514,6 +1532,7 @@ dependencies = [ "time", "tokio", "tokio-openssl", + "tokio-stream", "tracing", "tracing-subscriber", "url", @@ -1861,6 +1880,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "267ac89e0bec6e691e5813911606935d77c476ff49024f98abcea3e7b15e37af" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml_datetime" version = "0.6.3" From 10a8a862225ab6bb27e9136fb29aec8798584a27 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 25 May 2024 22:35:09 +0200 Subject: [PATCH 13/14] test fixes --- .../tests/integration/shard_aware_batching.rs | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/scylla/tests/integration/shard_aware_batching.rs b/scylla/tests/integration/shard_aware_batching.rs index a5c4040670..cf624cb3e9 100644 --- a/scylla/tests/integration/shard_aware_batching.rs +++ b/scylla/tests/integration/shard_aware_batching.rs @@ -197,16 +197,21 @@ async fn run_test( .await .expect("Each node should have received at least one message") .1 - .expect("Calls should be shard-aware") + .unwrap_or({ + // Cassandra case (non-scylla) + 0 + }) .into(), ); - loop { - match rx.try_recv() { - Ok((_, call_shard)) => { - shards_calls.push(call_shard.expect("Calls should all be shard-aware").into()) - } - Err(_) => break, - } + while let Ok((_, call_shard)) = rx.try_recv() { + shards_calls.push( + call_shard + .unwrap_or({ + // Cassandra case (non-scylla) + 0 + }) + .into(), + ) } nodes_shards_calls.push(shards_calls); } From d5f71a93ef73cac47c4d513ca4f2d231c168f7e5 Mon Sep 17 00:00:00 2001 From: Thomas BESSOU Date: Sat, 25 May 2024 22:38:16 +0200 Subject: [PATCH 14/14] make test deterministic --- scylla/tests/integration/shard_aware_batching.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scylla/tests/integration/shard_aware_batching.rs b/scylla/tests/integration/shard_aware_batching.rs index cf624cb3e9..5bae286bce 100644 --- 
a/scylla/tests/integration/shard_aware_batching.rs +++ b/scylla/tests/integration/shard_aware_batching.rs @@ -150,14 +150,17 @@ async fn run_test( ) }); + // Take a global lock to make test deterministic + // (and because we need to push stuff in there to test that shard-awareness is respected) + let mut shards_for_nodes_test_check = + shards_for_nodes_test_check_clone.lock().await; + session .batch(&scylla_batch, &batch) .await .expect("Query to send batch failed"); - shards_for_nodes_test_check_clone - .lock() - .await + shards_for_nodes_test_check .entry(destination_shard.node_id) .or_default() .push(destination_shard.shard_id_on_node);
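
For reference, here is a condensed sketch of how the pieces introduced by this series fit together on the caller's side, written against the API as it stands after patches 09-11. It is not taken verbatim from the patches: the keyspace, table, values, and the helper function name are placeholders, and the per-shard accumulation that a real application would do is only hinted at in comments.

```rust
use scylla::batch::{Batch, BatchType};
use scylla::Session;
use std::error::Error;

// Hypothetical helper: route one row into a batch pinned to its shard.
async fn shard_aware_insert(session: &Session) -> Result<(), Box<dyn Error>> {
    let prepared = session
        .prepare("INSERT INTO ks.tab(a, b) VALUES(?, ?)")
        .await?;

    // Serialize once: the same `SerializedValues` is used for routing and,
    // thanks to the `SerializeRow` specialization from patch 09, can be passed
    // back to `Session::batch` without being re-serialized.
    let serialized = prepared.serialize_values(&(1_i32, 2_i32))?;

    let mut batch = Batch::new(BatchType::Unlogged);
    // Ask the load balancer which (node, shard) this row would be routed to,
    // and pin the batch to it; if the node is unavailable at execution time,
    // the policy falls back to the previous load balancing policy.
    if let Some((node, shard)) = session.shard_for_statement(&prepared, &serialized)? {
        batch.enforce_target_node(&node, shard, session);
    }

    // In a real application, all statements routed to the same (node, shard)
    // would be accumulated into this batch before sending (see the
    // integration test added in patch 08).
    batch.append_statement(prepared);
    session.batch(&batch, vec![serialized]).await?;
    Ok(())
}
```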