Skip to content

Commit ed7299f

Browse files
committed
enum for CAS states, use paginator, add configuration for strong consistency
Signed-off-by: Darwin Boersma <[email protected]>
1 parent de4f334 commit ed7299f

File tree

2 files changed

+129
-91
lines changed

2 files changed

+129
-91
lines changed

crates/key-value-aws/src/lib.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ pub struct AwsDynamoKeyValueRuntimeConfig {
3030
token: Option<String>,
3131
/// The AWS region where the database is located
3232
region: String,
33+
/// Boolean determining whether to use strongly consistent reads.
34+
/// Defaults to `false`; set to `true` so that reads reflect all prior successful writes
35+
consistent_read: Option<bool>,
3336
/// The AWS Dynamo DB table.
3437
table: String,
3538
}
@@ -50,6 +53,7 @@ impl MakeKeyValueStore for AwsDynamoKeyValueStore {
5053
secret_key,
5154
token,
5255
region,
56+
consistent_read,
5357
table,
5458
} = runtime_config;
5559
let auth_options = match (access_key, secret_key) {
@@ -60,6 +64,11 @@ impl MakeKeyValueStore for AwsDynamoKeyValueStore {
6064
}
6165
_ => KeyValueAwsDynamoAuthOptions::Environmental,
6266
};
63-
KeyValueAwsDynamo::new(region, table, auth_options)
67+
KeyValueAwsDynamo::new(
68+
region,
69+
consistent_read.unwrap_or(false),
70+
table,
71+
auth_options,
72+
)
6473
}
6574
}

crates/key-value-aws/src/store.rs

Lines changed: 119 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,23 @@ use aws_sdk_dynamodb::{
1010
config::{ProvideCredentials, SharedCredentialsProvider},
1111
operation::{
1212
batch_get_item::BatchGetItemOutput, batch_write_item::BatchWriteItemOutput,
13-
get_item::GetItemOutput, update_item::UpdateItemOutput,
13+
get_item::GetItemOutput,
1414
},
1515
primitives::Blob,
16-
types::{
17-
AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
18-
WriteRequest,
19-
},
16+
types::{AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, WriteRequest},
2017
Client,
2118
};
2219
use spin_core::async_trait;
2320
use spin_factor_key_value::{log_error, Cas, Error, Store, StoreManager, SwapError};
2421

2522
pub struct KeyValueAwsDynamo {
23+
/// AWS region
2624
region: String,
27-
// Needs to be cloned when getting a store
25+
/// Whether to use strongly consistent reads
26+
consistent_read: bool,
27+
/// DynamoDB table, needs to be cloned when getting a store
2828
table: Arc<String>,
29+
/// DynamoDB client
2930
client: async_once_cell::Lazy<
3031
Client,
3132
std::pin::Pin<Box<dyn std::future::Future<Output = Client> + Send>>,
@@ -84,6 +85,7 @@ pub enum KeyValueAwsDynamoAuthOptions {
8485
impl KeyValueAwsDynamo {
8586
pub fn new(
8687
region: String,
88+
consistent_read: bool,
8789
table: String,
8890
auth_options: KeyValueAwsDynamoAuthOptions,
8991
) -> Result<Self> {
@@ -104,6 +106,7 @@ impl KeyValueAwsDynamo {
104106

105107
Ok(Self {
106108
region,
109+
consistent_read,
107110
table: Arc::new(table),
108111
client: async_once_cell::Lazy::from_future(client_fut),
109112
})
@@ -116,6 +119,7 @@ impl StoreManager for KeyValueAwsDynamo {
116119
Ok(Arc::new(AwsDynamoStore {
117120
client: self.client.get_unpin().await.clone(),
118121
table: self.table.clone(),
122+
consistent_read: self.consistent_read,
119123
}))
120124
}
121125

@@ -135,29 +139,43 @@ struct AwsDynamoStore {
135139
// Client wraps an Arc so should be low cost to clone
136140
client: Client,
137141
table: Arc<String>,
142+
consistent_read: bool,
143+
}
144+
145+
#[derive(Debug, Clone)]
146+
enum CasState {
147+
// Existing item with version
148+
Versioned(String),
149+
// Existing item without version
150+
Unversioned(Blob),
151+
// Item was missing, or had no binary value, when fetched during `current`
152+
Unset,
153+
// Potentially new item -- `current` was never called to fetch version
154+
Unknown,
138155
}
139156

140157
struct CompareAndSwap {
141158
key: String,
142159
client: Client,
143160
table: Arc<String>,
144161
bucket_rep: u32,
145-
has_lock: Mutex<bool>,
162+
state: Mutex<CasState>,
146163
}
147164

148165
/// Primary key in DynamoDB items used for querying items
149166
const PK: &str = "PK";
150167
/// Value key in DynamoDB items storing item value as binary
151-
const VAL: &str = "val";
152-
/// Lock key in DynamoDB items used for atomic operations
153-
const LOCK: &str = "lock";
168+
const VAL: &str = "VAL";
169+
/// Version key in DynamoDB items used for atomic operations
170+
const VER: &str = "VER";
154171

155172
#[async_trait]
156173
impl Store for AwsDynamoStore {
157174
async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, Error> {
158175
let response = self
159176
.client
160177
.get_item()
178+
.consistent_read(self.consistent_read)
161179
.table_name(self.table.as_str())
162180
.key(
163181
PK,
@@ -208,6 +226,7 @@ impl Store for AwsDynamoStore {
208226
let GetItemOutput { item, .. } = self
209227
.client
210228
.get_item()
229+
.consistent_read(self.consistent_read)
211230
.table_name(self.table.as_str())
212231
.key(
213232
PK,
@@ -228,8 +247,13 @@ impl Store for AwsDynamoStore {
228247
async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
229248
let mut results = Vec::with_capacity(keys.len());
230249

231-
let mut keys_and_attributes_builder =
232-
KeysAndAttributes::builder().projection_expression(format!("{PK},{VAL}"));
250+
if keys.is_empty() {
251+
return Ok(results);
252+
}
253+
254+
let mut keys_and_attributes_builder = KeysAndAttributes::builder()
255+
.projection_expression(format!("{PK},{VAL}"))
256+
.consistent_read(self.consistent_read);
233257
for key in keys {
234258
keys_and_attributes_builder = keys_and_attributes_builder.keys(HashMap::from_iter([(
235259
PK.to_owned(),
@@ -243,7 +267,7 @@ impl Store for AwsDynamoStore {
243267

244268
while request_items.is_some() {
245269
let BatchGetItemOutput {
246-
responses: Some(mut responses),
270+
responses,
247271
unprocessed_keys,
248272
..
249273
} = self
@@ -252,25 +276,21 @@ impl Store for AwsDynamoStore {
252276
.set_request_items(request_items)
253277
.send()
254278
.await
255-
.map_err(log_error)?
256-
else {
257-
return Err(Error::Other("No results".into()));
258-
};
279+
.map_err(log_error)?;
259280

260-
if let Some(items) = responses.remove(self.table.as_str()) {
281+
if let Some(items) =
282+
responses.and_then(|mut responses| responses.remove(self.table.as_str()))
283+
{
261284
for mut item in items {
262-
let Some(AttributeValue::S(pk)) = item.remove(PK) else {
263-
return Err(Error::Other(
264-
"Could not find 'PK' key on DynamoDB item".into(),
265-
));
266-
};
267-
let Some(AttributeValue::B(val)) = item.remove(VAL) else {
268-
return Err(Error::Other(
269-
"Could not find 'val' key on DynamoDB item".into(),
270-
));
271-
};
272-
273-
results.push((pk, Some(val.into_inner())));
285+
match (item.remove(PK), item.remove(VAL)) {
286+
(Some(AttributeValue::S(pk)), Some(AttributeValue::B(val))) => {
287+
results.push((pk, Some(val.into_inner())));
288+
}
289+
(Some(AttributeValue::S(pk)), None) => {
290+
results.push((pk, None));
291+
}
292+
_ => (),
293+
}
274294
}
275295
}
276296

@@ -355,8 +375,8 @@ impl Store for AwsDynamoStore {
355375
.update_item()
356376
.table_name(self.table.as_str())
357377
.key(PK, AttributeValue::S(key))
358-
.update_expression("ADD #val :delta")
359-
.expression_attribute_names("#val", VAL)
378+
.update_expression("ADD #VAL :delta")
379+
.expression_attribute_names("#VAL", VAL)
360380
.expression_attribute_values(":delta", AttributeValue::N(delta.to_string()))
361381
.return_values(aws_sdk_dynamodb::types::ReturnValue::UpdatedNew)
362382
.send()
@@ -381,7 +401,7 @@ impl Store for AwsDynamoStore {
381401
key: key.to_string(),
382402
client: self.client.clone(),
383403
table: self.table.clone(),
384-
has_lock: Mutex::new(false),
404+
state: Mutex::new(CasState::Unknown),
385405
bucket_rep,
386406
}))
387407
}
@@ -390,60 +410,80 @@ impl Store for AwsDynamoStore {
390410
#[async_trait]
391411
impl Cas for CompareAndSwap {
392412
async fn current(&self) -> Result<Option<Vec<u8>>, Error> {
393-
let UpdateItemOutput { attributes, .. } = self
413+
let GetItemOutput { item, .. } = self
394414
.client
395-
.update_item()
415+
.get_item()
416+
.consistent_read(true)
396417
.table_name(self.table.as_str())
397418
.key(PK, AttributeValue::S(self.key.clone()))
398-
.update_expression("SET #lock=:lock")
399-
.expression_attribute_names("#lock", LOCK)
400-
.expression_attribute_values(":lock", AttributeValue::Null(true))
401-
.condition_expression("attribute_not_exists (#lock)")
402-
.return_values(aws_sdk_dynamodb::types::ReturnValue::AllNew)
419+
.projection_expression(format!("{VAL},{VER}"))
403420
.send()
404421
.await
405422
.map_err(log_error)?;
406423

407-
self.has_lock.lock().unwrap().clone_from(&true);
424+
match item {
425+
Some(mut current_item) => match (current_item.remove(VAL), current_item.remove(VER)) {
426+
(Some(AttributeValue::B(val)), Some(AttributeValue::N(ver))) => {
427+
self.state
428+
.lock()
429+
.unwrap()
430+
.clone_from(&CasState::Versioned(ver));
431+
432+
Ok(Some(val.into_inner()))
433+
}
434+
(Some(AttributeValue::B(val)), _) => {
435+
self.state
436+
.lock()
437+
.unwrap()
438+
.clone_from(&CasState::Unversioned(val.clone()));
408439

409-
match attributes {
410-
Some(mut item) => match item.remove(VAL) {
411-
Some(AttributeValue::B(val)) => Ok(Some(val.into_inner())),
412-
_ => Ok(None),
440+
Ok(Some(val.into_inner()))
441+
}
442+
(_, _) => {
443+
self.state.lock().unwrap().clone_from(&CasState::Unset);
444+
Ok(None)
445+
}
413446
},
414-
None => Ok(None),
447+
None => {
448+
self.state.lock().unwrap().clone_from(&CasState::Unset);
449+
Ok(None)
450+
}
415451
}
416452
}
417453

418-
/// `swap` updates the value for the key using the version saved in the `current` function for
419-
/// optimistic concurrency.
454+
/// `swap` updates the value for the key. If `current` was called first, the write is conditional on the
455+
/// state it saved (version, previous value, or absence of a value), giving optimistic concurrency
420456
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
421-
let mut update_item = Update::builder()
457+
let mut update_item = self
458+
.client
459+
.update_item()
422460
.table_name(self.table.as_str())
423461
.key(PK, AttributeValue::S(self.key.clone()))
424-
.update_expression("SET #val=:val REMOVE #lock")
425-
.expression_attribute_names("#val", VAL)
462+
.update_expression("SET #VAL = :val ADD #VER :increment")
463+
.expression_attribute_names("#VAL", VAL)
464+
.expression_attribute_names("#VER", VER)
426465
.expression_attribute_values(":val", AttributeValue::B(Blob::new(value)))
427-
.expression_attribute_names("#lock", LOCK);
428-
429-
let has_lock = *self.has_lock.lock().unwrap();
430-
// Ensure exclusive access between fetching the current value of the item and swapping
431-
if has_lock {
432-
update_item = update_item.condition_expression("attribute_exists (#lock)");
433-
}
466+
.expression_attribute_values(":increment", AttributeValue::N("1".to_owned()));
467+
468+
let state = self.state.lock().unwrap().clone();
469+
match state {
470+
CasState::Versioned(version) => {
471+
update_item = update_item
472+
.condition_expression("#VER = :ver")
473+
.expression_attribute_values(":ver", AttributeValue::N(version));
474+
}
475+
CasState::Unversioned(old_val) => {
476+
update_item = update_item
477+
.condition_expression("#VAL = :old_val")
478+
.expression_attribute_values(":old_val", AttributeValue::B(old_val));
479+
}
480+
CasState::Unset => {
481+
update_item = update_item.condition_expression("attribute_not_exists (#VAL)");
482+
}
483+
CasState::Unknown => (),
484+
};
434485

435-
// TransactWriteItems fails if concurrent writes are in progress on an item, so even without locking, we get atomicity in overwriting
436-
self.client
437-
.transact_write_items()
438-
.transact_items(
439-
TransactWriteItem::builder()
440-
.update(
441-
update_item
442-
.build()
443-
.map_err(|e| SwapError::Other(format!("{e:?}")))?,
444-
)
445-
.build(),
446-
)
486+
update_item
447487
.send()
448488
.await
449489
.map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;
@@ -463,35 +503,24 @@ impl Cas for CompareAndSwap {
463503
impl AwsDynamoStore {
464504
async fn get_keys(&self) -> Result<Vec<String>, Error> {
465505
let mut primary_keys = Vec::new();
466-
let mut last_evaluated_key = None;
467506

468-
loop {
469-
let mut scan_builder = self
470-
.client
471-
.scan()
472-
.table_name(self.table.as_str())
473-
.projection_expression(PK);
474-
475-
if let Some(keys) = last_evaluated_key {
476-
for (key, val) in keys {
477-
scan_builder = scan_builder.exclusive_start_key(key, val);
478-
}
479-
}
480-
481-
let scan_output = scan_builder.send().await.map_err(log_error)?;
507+
let mut scan_paginator = self
508+
.client
509+
.scan()
510+
.table_name(self.table.as_str())
511+
.projection_expression(PK)
512+
.into_paginator()
513+
.send();
482514

515+
while let Some(output) = scan_paginator.next().await {
516+
let scan_output = output.map_err(log_error)?;
483517
if let Some(items) = scan_output.items {
484518
for mut item in items {
485519
if let Some(AttributeValue::S(pk)) = item.remove(PK) {
486520
primary_keys.push(pk);
487521
}
488522
}
489523
}
490-
491-
last_evaluated_key = scan_output.last_evaluated_key;
492-
if last_evaluated_key.is_none() {
493-
break;
494-
}
495524
}
496525

497526
Ok(primary_keys)

0 commit comments

Comments
 (0)