Skip to content

Commit 1cc913c

Browse files
committed
Use transaction in increment and swap for better atomicity, remove unneeded exists check, higher level filtering of empty get_all queries, sqlite handle null value before swap
Signed-off-by: Darwin Boersma <[email protected]>
1 parent ed7299f commit 1cc913c

File tree

4 files changed

+146
-81
lines changed

4 files changed

+146
-81
lines changed

crates/factor-key-value/src/host.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch {
283283
keys: Vec<String>,
284284
) -> std::result::Result<Vec<(String, Option<Vec<u8>>)>, wasi_keyvalue::store::Error> {
285285
let store = self.get_store_wasi(bucket)?;
286+
if keys.is_empty() {
287+
return Ok(vec![]);
288+
}
286289
store.get_many(keys).await.map_err(to_wasi_err)
287290
}
288291

@@ -293,6 +296,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch {
293296
key_values: Vec<(String, Vec<u8>)>,
294297
) -> std::result::Result<(), wasi_keyvalue::store::Error> {
295298
let store = self.get_store_wasi(bucket)?;
299+
if key_values.is_empty() {
300+
return Ok(());
301+
}
296302
store.set_many(key_values).await.map_err(to_wasi_err)
297303
}
298304

@@ -303,6 +309,9 @@ impl wasi_keyvalue::batch::Host for KeyValueDispatch {
303309
keys: Vec<String>,
304310
) -> std::result::Result<(), wasi_keyvalue::store::Error> {
305311
let store = self.get_store_wasi(bucket)?;
312+
if keys.is_empty() {
313+
return Ok(());
314+
}
306315
store.delete_many(keys).await.map_err(to_wasi_err)
307316
}
308317
}

crates/factor-key-value/src/util.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,12 @@ impl Store for CachingStore {
260260
}
261261
}
262262

263-
let keys_and_values = self.inner.get_many(not_found).await?;
264-
for (key, value) in keys_and_values {
265-
found.push((key.clone(), value.clone()));
266-
state.cache.put(key, value);
263+
if !not_found.is_empty() {
264+
let keys_and_values = self.inner.get_many(not_found).await?;
265+
for (key, value) in keys_and_values {
266+
found.push((key.clone(), value.clone()));
267+
state.cache.put(key, value);
268+
}
267269
}
268270

269271
Ok(found)

crates/key-value-aws/src/store.rs

Lines changed: 102 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use core::str;
12
use std::{
23
collections::HashMap,
34
sync::{Arc, Mutex},
@@ -13,7 +14,10 @@ use aws_sdk_dynamodb::{
1314
get_item::GetItemOutput,
1415
},
1516
primitives::Blob,
16-
types::{AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, WriteRequest},
17+
types::{
18+
AttributeValue, DeleteRequest, KeysAndAttributes, PutRequest, TransactWriteItem, Update,
19+
WriteRequest,
20+
},
1721
Client,
1822
};
1923
use spin_core::async_trait;
@@ -148,7 +152,7 @@ enum CasState {
148152
Versioned(String),
149153
// Existing item without version
150154
Unversioned(Blob),
151-
// Item was null when fetched during `current`
155+
// Item was missing when fetched during `current`, expected to be new
152156
Unset,
153157
// Potentially new item -- `current` was never called to fetch version
154158
Unknown,
@@ -210,15 +214,13 @@ impl Store for AwsDynamoStore {
210214
}
211215

212216
async fn delete(&self, key: &str) -> Result<(), Error> {
213-
if self.exists(key).await? {
214-
self.client
215-
.delete_item()
216-
.table_name(self.table.as_str())
217-
.key(PK, AttributeValue::S(key.to_string()))
218-
.send()
219-
.await
220-
.map_err(log_error)?;
221-
}
217+
self.client
218+
.delete_item()
219+
.table_name(self.table.as_str())
220+
.key(PK, AttributeValue::S(key.to_string()))
221+
.send()
222+
.await
223+
.map_err(log_error)?;
222224
Ok(())
223225
}
224226

@@ -241,16 +243,32 @@ impl Store for AwsDynamoStore {
241243
}
242244

243245
async fn get_keys(&self) -> Result<Vec<String>, Error> {
244-
self.get_keys().await
245-
}
246+
let mut primary_keys = Vec::new();
246247

247-
async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
248-
let mut results = Vec::with_capacity(keys.len());
248+
let mut scan_paginator = self
249+
.client
250+
.scan()
251+
.table_name(self.table.as_str())
252+
.projection_expression(PK)
253+
.into_paginator()
254+
.send();
249255

250-
if keys.is_empty() {
251-
return Ok(results);
256+
while let Some(output) = scan_paginator.next().await {
257+
let scan_output = output.map_err(log_error)?;
258+
if let Some(items) = scan_output.items {
259+
for mut item in items {
260+
if let Some(AttributeValue::S(pk)) = item.remove(PK) {
261+
primary_keys.push(pk);
262+
}
263+
}
264+
}
252265
}
253266

267+
Ok(primary_keys)
268+
}
269+
270+
async fn get_many(&self, keys: Vec<String>) -> Result<Vec<(String, Option<Vec<u8>>)>, Error> {
271+
let mut results = Vec::with_capacity(keys.len());
254272
let mut keys_and_attributes_builder = KeysAndAttributes::builder()
255273
.projection_expression(format!("{PK},{VAL}"))
256274
.consistent_read(self.consistent_read);
@@ -370,26 +388,66 @@ impl Store for AwsDynamoStore {
370388
}
371389

372390
async fn increment(&self, key: String, delta: i64) -> Result<i64, Error> {
373-
let result = self
391+
let GetItemOutput { item, .. } = self
374392
.client
375-
.update_item()
393+
.get_item()
394+
.consistent_read(true)
376395
.table_name(self.table.as_str())
377-
.key(PK, AttributeValue::S(key))
378-
.update_expression("ADD #VAL :delta")
379-
.expression_attribute_names("#VAL", VAL)
380-
.expression_attribute_values(":delta", AttributeValue::N(delta.to_string()))
381-
.return_values(aws_sdk_dynamodb::types::ReturnValue::UpdatedNew)
396+
.key(PK, AttributeValue::S(key.clone()))
397+
.projection_expression(VAL)
382398
.send()
383399
.await
384400
.map_err(log_error)?;
385401

386-
if let Some(updated_attributes) = result.attributes {
387-
if let Some(AttributeValue::N(new_value)) = updated_attributes.get(VAL) {
388-
return Ok(new_value.parse::<i64>().map_err(log_error))?;
389-
}
402+
let old_val = match item {
403+
Some(mut current_item) => match current_item.remove(VAL) {
404+
// We're expecting i64, so technically we could transmute but seems risky...
405+
Some(AttributeValue::B(val)) => Some(
406+
str::from_utf8(&val.into_inner())
407+
.map_err(log_error)?
408+
.parse::<i64>()
409+
.map_err(log_error)?,
410+
),
411+
_ => None,
412+
},
413+
None => None,
414+
};
415+
416+
let new_val = old_val.unwrap_or(0) + delta;
417+
418+
let mut update = Update::builder()
419+
.table_name(self.table.as_str())
420+
.key(PK, AttributeValue::S(key))
421+
.update_expression("SET #VAL = :new_val")
422+
.expression_attribute_names("#VAL", VAL)
423+
.expression_attribute_values(
424+
":new_val",
425+
AttributeValue::B(Blob::new(new_val.to_string().as_bytes())),
426+
);
427+
428+
if let Some(old_val) = old_val {
429+
update = update
430+
.condition_expression("#VAL = :old_val")
431+
.expression_attribute_values(
432+
":old_val",
433+
AttributeValue::B(Blob::new(old_val.to_string().as_bytes())),
434+
)
435+
} else {
436+
update = update.condition_expression("attribute_not_exists (#VAL)")
390437
}
391438

392-
Err(Error::Other("Failed to increment value".into()))
439+
self.client
440+
.transact_write_items()
441+
.transact_items(
442+
TransactWriteItem::builder()
443+
.update(update.build().map_err(log_error)?)
444+
.build(),
445+
)
446+
.send()
447+
.await
448+
.map_err(log_error)?;
449+
450+
Ok(new_val)
393451
}
394452

395453
async fn new_compare_and_swap(
@@ -454,9 +512,7 @@ impl Cas for CompareAndSwap {
454512
/// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
455513
/// optimistic concurrency or the previous item value
456514
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
457-
let mut update_item = self
458-
.client
459-
.update_item()
515+
let mut update = Update::builder()
460516
.table_name(self.table.as_str())
461517
.key(PK, AttributeValue::S(self.key.clone()))
462518
.update_expression("SET #VAL = :val ADD #VER :increment")
@@ -468,22 +524,32 @@ impl Cas for CompareAndSwap {
468524
let state = self.state.lock().unwrap().clone();
469525
match state {
470526
CasState::Versioned(version) => {
471-
update_item = update_item
527+
update = update
472528
.condition_expression("#VER = :ver")
473529
.expression_attribute_values(":ver", AttributeValue::N(version));
474530
}
475531
CasState::Unversioned(old_val) => {
476-
update_item = update_item
532+
update = update
477533
.condition_expression("#VAL = :old_val")
478534
.expression_attribute_values(":old_val", AttributeValue::B(old_val));
479535
}
480536
CasState::Unset => {
481-
update_item = update_item.condition_expression("attribute_not_exists (#VAL)");
537+
update = update.condition_expression("attribute_not_exists (#VAL)");
482538
}
483539
CasState::Unknown => (),
484540
};
485541

486-
update_item
542+
self.client
543+
.transact_write_items()
544+
.transact_items(
545+
TransactWriteItem::builder()
546+
.update(
547+
update
548+
.build()
549+
.map_err(|e| SwapError::Other(format!("{e:?}")))?,
550+
)
551+
.build(),
552+
)
487553
.send()
488554
.await
489555
.map_err(|e| SwapError::CasFailed(format!("{e:?}")))?;
@@ -499,30 +565,3 @@ impl Cas for CompareAndSwap {
499565
self.key.clone()
500566
}
501567
}
502-
503-
impl AwsDynamoStore {
504-
async fn get_keys(&self) -> Result<Vec<String>, Error> {
505-
let mut primary_keys = Vec::new();
506-
507-
let mut scan_paginator = self
508-
.client
509-
.scan()
510-
.table_name(self.table.as_str())
511-
.projection_expression(PK)
512-
.into_paginator()
513-
.send();
514-
515-
while let Some(output) = scan_paginator.next().await {
516-
let scan_output = output.map_err(log_error)?;
517-
if let Some(items) = scan_output.items {
518-
for mut item in items {
519-
if let Some(AttributeValue::S(pk)) = item.remove(PK) {
520-
primary_keys.push(pk);
521-
}
522-
}
523-
}
524-
}
525-
526-
Ok(primary_keys)
527-
}
528-
}

crates/key-value-spin/src/store.rs

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -307,20 +307,35 @@ impl Cas for CompareAndSwap {
307307
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
308308
task::block_in_place(|| {
309309
let old_value = self.value.lock().unwrap();
310-
let rows_changed = self.connection
311-
.lock()
312-
.unwrap()
313-
.prepare_cached(
314-
"UPDATE spin_key_value SET value=:new_value WHERE store=:name and key=:key and value=:old_value",
315-
)
316-
.map_err(log_cas_error)?
317-
.execute(named_params! {
318-
":name": &self.name,
319-
":key": self.key,
320-
":old_value": old_value.clone().unwrap(),
321-
":new_value": value,
322-
})
323-
.map_err(log_cas_error)?;
310+
let mut conn = self.connection.lock().unwrap();
311+
let rows_changed = match old_value.clone() {
312+
Some(old_val) => {
313+
conn
314+
.prepare_cached(
315+
"UPDATE spin_key_value SET value=:new_value WHERE store=:name and key=:key and value=:old_value")
316+
.map_err(log_cas_error)?
317+
.execute(named_params! {
318+
":name": &self.name,
319+
":key": self.key,
320+
":old_value": old_val,
321+
":new_value": value,
322+
})
323+
.map_err(log_cas_error)?
324+
}
325+
None => {
326+
let tx = conn.transaction().map_err(log_cas_error)?;
327+
let rows = tx
328+
.prepare_cached(
329+
"INSERT INTO spin_key_value (store, key, value) VALUES ($1, $2, $3)
330+
ON CONFLICT(store, key) DO UPDATE SET value=$3",
331+
)
332+
.map_err(log_cas_error)?
333+
.execute(rusqlite::params![&self.name, self.key, value])
334+
.map_err(log_cas_error)?;
335+
tx.commit().map_err(log_cas_error)?;
336+
rows
337+
}
338+
};
324339

325340
// We expect only 1 row to be updated. If 0, we know that the underlying value has changed.
326341
if rows_changed == 1 {

0 commit comments

Comments
 (0)