Commit 59ced8f

wip example test

1 parent 128f741

File tree: 3 files changed, +196 −0 lines


scylla/Cargo.toml

Lines changed: 2 additions & 0 deletions

@@ -60,6 +60,8 @@ criterion = "0.5"
 tracing-subscriber = { version = "0.3.14", features = ["env-filter"] }
 assert_matches = "1.5.0"
 rand_chacha = "0.3.1"
+futures-batch = "0.6.1"
+tokio-stream = "0.1.14"
 
 [[bench]]
 name = "benchmark"

scylla/tests/integration/main.rs

Lines changed: 1 addition & 0 deletions

@@ -4,4 +4,5 @@ mod hygiene;
 mod lwt_optimisation;
 mod new_session;
 mod retries;
+mod shard_aware_batching;
 pub(crate) mod utils;
scylla/tests/integration/shard_aware_batching.rs

Lines changed: 193 additions & 0 deletions

@@ -0,0 +1,193 @@
use crate::utils::test_with_3_node_cluster;
use futures::prelude::*;
use futures_batch::ChunksTimeoutStreamExt;
use scylla::frame::value::ValueList;
use scylla::retry_policy::FallthroughRetryPolicy;
use scylla::test_utils::unique_keyspace_name;
use scylla::transport::session::Session;
use scylla::{ExecutionProfile, SessionBuilder};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;

use scylla_proxy::{
    Condition, ProxyError, Reaction, RequestFrame, RequestOpcode, RequestReaction, RequestRule,
    RunningProxy, ShardAwareness, WorkerError,
};
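
// `scylla_proxy` sits between the driver and each node of the test cluster;
// request rules let us observe, via feedback channels, which node received
// which frames, so the test can assert on the driver's routing decisions.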

#[tokio::test]
#[ntest::timeout(20000)]
#[cfg(not(scylla_cloud_tests))]
async fn shard_aware_batching_pattern_routes_to_proper_shard() {
    let res = test_with_3_node_cluster(ShardAwareness::QueryNode, run_test).await;

    match res {
        Ok(()) => (),
        Err(ProxyError::Worker(WorkerError::DriverDisconnected(_))) => (),
        Err(err) => panic!("{}", err),
    }
}
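
// The pattern demonstrated below: for each row, compute which (node, shard)
// owns its partition, send the serialized values to a channel dedicated to
// that destination, and let a per-destination task accumulate them into
// unlogged batches that are forced onto the owning node.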

async fn run_test(
    proxy_uris: [String; 3],
    translation_map: HashMap<std::net::SocketAddr, std::net::SocketAddr>,
    mut running_proxy: RunningProxy,
) -> RunningProxy {
    // This is just to increase the likelihood that only the intended prepared
    // statements (which contain this mark) are captured by the proxy.
    const MAGIC_MARK: i32 = 123;

    // We set up the proxy so that it passes us information about which node
    // was queried (via prepared_rx).
    let prepared_rule = |tx| {
        RequestRule(
            Condition::and(
                Condition::RequestOpcode(RequestOpcode::Batch),
                Condition::BodyContainsCaseSensitive(Box::new(MAGIC_MARK.to_be_bytes())),
            ),
            RequestReaction::noop().with_feedback_when_performed(tx),
        )
    };

    let mut prepared_rxs = [0, 1, 2].map(|i| {
        let (prepared_tx, prepared_rx) = mpsc::unbounded_channel();
        running_proxy.running_nodes[i].change_request_rules(Some(vec![prepared_rule(prepared_tx)]));
        prepared_rx
    });
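
    // FallthroughRetryPolicy never retries on its own, so every batch is sent
    // exactly once and the proxy feedback reflects only the driver's initial
    // routing decision.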
    let handle = ExecutionProfile::builder()
        .retry_policy(Box::new(FallthroughRetryPolicy))
        .build()
        .into_handle();

    // DB preparation phase
    let session: Arc<Session> = Arc::new(
        SessionBuilder::new()
            .known_node(proxy_uris[0].as_str())
            .default_execution_profile_handle(handle)
            .address_translator(Arc::new(translation_map))
            .build()
            .await
            .unwrap(),
    );
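
    // The session above connects through the proxy: the translation map
    // rewrites the node addresses the cluster advertises into their proxied
    // counterparts.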

    // Create schema
    let ks = unique_keyspace_name();
    session.query(format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{'class' : 'NetworkTopologyStrategy', 'replication_factor' : 3}}", ks), &[]).await.unwrap();
    session.use_keyspace(ks, false).await.unwrap();

    session
        .query("CREATE TABLE t (a int primary key, b int)", &[])
        .await
        .unwrap();

    // We will check which nodes were queried.
    let prepared_statement = session
        .prepare("INSERT INTO t (a, b) VALUES (?, ?)")
        .await
        .unwrap();

    assert!(prepared_statement.is_token_aware());
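
    // Because the statement is token-aware, the driver can compute, for a
    // given set of bound values, which node and shard own the partition;
    // `shard_for_statement` below exposes exactly that.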

    // Build the shard-aware batching system

    #[derive(Clone, Copy, PartialEq, Eq, Hash)]
    struct DestinationShard {
        node_id: uuid::Uuid,
        shard_id_on_node: Option<u32>,
    }
    let mut channels_for_shards: HashMap<
        DestinationShard,
        tokio::sync::mpsc::Sender<scylla::frame::value::SerializedValues>,
    > = HashMap::new();
    let mut batching_tasks: Vec<tokio::task::JoinHandle<()>> = Vec::new(); // To make sure nothing panicked
    for i in 0..150 {
        let values = (i, MAGIC_MARK);

        let serialized_values = values
            .serialized()
            .expect("Failed to serialize values")
            .into_owned();

        let (node, shard_id_on_node) = session
            .shard_for_statement(&prepared_statement, &serialized_values)
            .expect("Error when getting shard for statement")
            .expect("Query is not shard-aware");
        let destination_shard = DestinationShard {
            node_id: node.host_id,
            shard_id_on_node,
        };

        // Typically, if values may come from several places, the
        // `channels_for_shards` `HashMap` would be behind a mutex, but for
        // this example we keep it simple.
        // Create the task that builds and sends the batches for this shard,
        // if it doesn't already exist.
        let sender = channels_for_shards
            .entry(destination_shard)
            .or_insert_with(|| {
                let (sender, receiver) = tokio::sync::mpsc::channel(10000);
                let prepared_statement = prepared_statement.clone();
                let session = session.clone();

                let mut scylla_batch =
                    scylla::batch::Batch::new(scylla::batch::BatchType::Unlogged);
                scylla_batch.enforce_target_node(&node, &session);
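
                // `chunks_timeout(10, 100ms)` yields a chunk as soon as 10
                // values are buffered, or when the timeout elapses, whichever
                // comes first: batches stay small and no value waits forever.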
                batching_tasks.push(tokio::spawn(async move {
                    let mut batches = ReceiverStream::new(receiver)
                        .chunks_timeout(10, Duration::from_millis(100));

                    while let Some(batch) = batches.next().await {
                        // Obviously, if the actual prepared statement depends
                        // on each element of the batch, this requires
                        // adjustment.
                        scylla_batch.statements.resize_with(batch.len(), || {
                            scylla::batch::BatchStatement::PreparedStatement(
                                prepared_statement.clone(),
                            )
                        });

                        session
                            .batch(&scylla_batch, &batch)
                            .await
                            .expect("Query to send batch failed");
                    }
                }));
                sender
            });

        sender
            .send(serialized_values)
            .await
            .expect("Failed to send serialized values to dedicated channel");
    }

    // Drop the senders, which ensures that all remaining batches are sent
    // immediately, then wait for all the tasks to finish and check that there
    // were no errors.
    // In a production setting, these dynamically instantiated tasks may be
    // monitored more easily by using e.g. `tokio_tasks_shutdown`.
    std::mem::drop(channels_for_shards);
    for task in batching_tasks {
        task.await.unwrap();
    }

    // Finally, check that batching was indeed shard-aware.

    // TODO

    // wip: make sure we did capture the queries to each node
    fn clear_rxs(rxs: &mut [mpsc::UnboundedReceiver<RequestFrame>; 3]) {
        for rx in rxs.iter_mut() {
            while rx.try_recv().is_ok() {}
        }
    }
    async fn assert_all_replicas_queried(rxs: &mut [mpsc::UnboundedReceiver<RequestFrame>; 3]) {
        for rx in rxs.iter_mut() {
            rx.recv().await.unwrap();
        }
        clear_rxs(rxs);
    }
    assert_all_replicas_queried(&mut prepared_rxs).await;

    running_proxy
}
