1+ use core:: str;
12use std:: {
23 collections:: HashMap ,
34 sync:: { Arc , Mutex } ,
@@ -13,7 +14,10 @@ use aws_sdk_dynamodb::{
1314 get_item:: GetItemOutput ,
1415 } ,
1516 primitives:: Blob ,
16- types:: { AttributeValue , DeleteRequest , KeysAndAttributes , PutRequest , WriteRequest } ,
17+ types:: {
18+ AttributeValue , DeleteRequest , KeysAndAttributes , PutRequest , TransactWriteItem , Update ,
19+ WriteRequest ,
20+ } ,
1721 Client ,
1822} ;
1923use spin_core:: async_trait;
@@ -148,7 +152,7 @@ enum CasState {
148152 Versioned ( String ) ,
149153 // Existing item without version
150154 Unversioned ( Blob ) ,
151- // Item was null when fetched during `current`
155+ // Item was missing when fetched during `current`, expected to be new
152156 Unset ,
153157 // Potentially new item -- `current` was never called to fetch version
154158 Unknown ,
@@ -210,15 +214,13 @@ impl Store for AwsDynamoStore {
210214 }
211215
212216 async fn delete ( & self , key : & str ) -> Result < ( ) , Error > {
213- if self . exists ( key) . await ? {
214- self . client
215- . delete_item ( )
216- . table_name ( self . table . as_str ( ) )
217- . key ( PK , AttributeValue :: S ( key. to_string ( ) ) )
218- . send ( )
219- . await
220- . map_err ( log_error) ?;
221- }
217+ self . client
218+ . delete_item ( )
219+ . table_name ( self . table . as_str ( ) )
220+ . key ( PK , AttributeValue :: S ( key. to_string ( ) ) )
221+ . send ( )
222+ . await
223+ . map_err ( log_error) ?;
222224 Ok ( ( ) )
223225 }
224226
@@ -241,16 +243,32 @@ impl Store for AwsDynamoStore {
241243 }
242244
243245 async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
244- self . get_keys ( ) . await
245- }
246+ let mut primary_keys = Vec :: new ( ) ;
246247
247- async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
248- let mut results = Vec :: with_capacity ( keys. len ( ) ) ;
248+ let mut scan_paginator = self
249+ . client
250+ . scan ( )
251+ . table_name ( self . table . as_str ( ) )
252+ . projection_expression ( PK )
253+ . into_paginator ( )
254+ . send ( ) ;
249255
250- if keys. is_empty ( ) {
251- return Ok ( results) ;
256+ while let Some ( output) = scan_paginator. next ( ) . await {
257+ let scan_output = output. map_err ( log_error) ?;
258+ if let Some ( items) = scan_output. items {
259+ for mut item in items {
260+ if let Some ( AttributeValue :: S ( pk) ) = item. remove ( PK ) {
261+ primary_keys. push ( pk) ;
262+ }
263+ }
264+ }
252265 }
253266
267+ Ok ( primary_keys)
268+ }
269+
270+ async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
271+ let mut results = Vec :: with_capacity ( keys. len ( ) ) ;
254272 let mut keys_and_attributes_builder = KeysAndAttributes :: builder ( )
255273 . projection_expression ( format ! ( "{PK},{VAL}" ) )
256274 . consistent_read ( self . consistent_read ) ;
@@ -370,26 +388,66 @@ impl Store for AwsDynamoStore {
370388 }
371389
372390 async fn increment ( & self , key : String , delta : i64 ) -> Result < i64 , Error > {
373- let result = self
391+ let GetItemOutput { item , .. } = self
374392 . client
375- . update_item ( )
393+ . get_item ( )
394+ . consistent_read ( true )
376395 . table_name ( self . table . as_str ( ) )
377- . key ( PK , AttributeValue :: S ( key) )
378- . update_expression ( "ADD #VAL :delta" )
379- . expression_attribute_names ( "#VAL" , VAL )
380- . expression_attribute_values ( ":delta" , AttributeValue :: N ( delta. to_string ( ) ) )
381- . return_values ( aws_sdk_dynamodb:: types:: ReturnValue :: UpdatedNew )
396+ . key ( PK , AttributeValue :: S ( key. clone ( ) ) )
397+ . projection_expression ( VAL )
382398 . send ( )
383399 . await
384400 . map_err ( log_error) ?;
385401
386- if let Some ( updated_attributes) = result. attributes {
387- if let Some ( AttributeValue :: N ( new_value) ) = updated_attributes. get ( VAL ) {
388- return Ok ( new_value. parse :: < i64 > ( ) . map_err ( log_error) ) ?;
389- }
402+ let old_val = match item {
403+ Some ( mut current_item) => match current_item. remove ( VAL ) {
404+ // We're expecting i64, so technically we could transmute but seems risky...
405+ Some ( AttributeValue :: B ( val) ) => Some (
406+ str:: from_utf8 ( & val. into_inner ( ) )
407+ . map_err ( log_error) ?
408+ . parse :: < i64 > ( )
409+ . map_err ( log_error) ?,
410+ ) ,
411+ _ => None ,
412+ } ,
413+ None => None ,
414+ } ;
415+
416+ let new_val = old_val. unwrap_or ( 0 ) + delta;
417+
418+ let mut update = Update :: builder ( )
419+ . table_name ( self . table . as_str ( ) )
420+ . key ( PK , AttributeValue :: S ( key) )
421+ . update_expression ( "SET #VAL = :new_val" )
422+ . expression_attribute_names ( "#VAL" , VAL )
423+ . expression_attribute_values (
424+ ":new_val" ,
425+ AttributeValue :: B ( Blob :: new ( new_val. to_string ( ) . as_bytes ( ) ) ) ,
426+ ) ;
427+
428+ if let Some ( old_val) = old_val {
429+ update = update
430+ . condition_expression ( "#VAL = :old_val" )
431+ . expression_attribute_values (
432+ ":old_val" ,
433+ AttributeValue :: B ( Blob :: new ( old_val. to_string ( ) . as_bytes ( ) ) ) ,
434+ )
435+ } else {
436+ update = update. condition_expression ( "attribute_not_exists (#VAL)" )
390437 }
391438
392- Err ( Error :: Other ( "Failed to increment value" . into ( ) ) )
439+ self . client
440+ . transact_write_items ( )
441+ . transact_items (
442+ TransactWriteItem :: builder ( )
443+ . update ( update. build ( ) . map_err ( log_error) ?)
444+ . build ( ) ,
445+ )
446+ . send ( )
447+ . await
448+ . map_err ( log_error) ?;
449+
450+ Ok ( new_val)
393451 }
394452
395453 async fn new_compare_and_swap (
@@ -454,9 +512,7 @@ impl Cas for CompareAndSwap {
454512 /// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
455513 /// optimistic concurrency or the previous item value
456514 async fn swap ( & self , value : Vec < u8 > ) -> Result < ( ) , SwapError > {
457- let mut update_item = self
458- . client
459- . update_item ( )
515+ let mut update = Update :: builder ( )
460516 . table_name ( self . table . as_str ( ) )
461517 . key ( PK , AttributeValue :: S ( self . key . clone ( ) ) )
462518 . update_expression ( "SET #VAL = :val ADD #VER :increment" )
@@ -468,22 +524,32 @@ impl Cas for CompareAndSwap {
468524 let state = self . state . lock ( ) . unwrap ( ) . clone ( ) ;
469525 match state {
470526 CasState :: Versioned ( version) => {
471- update_item = update_item
527+ update = update
472528 . condition_expression ( "#VER = :ver" )
473529 . expression_attribute_values ( ":ver" , AttributeValue :: N ( version) ) ;
474530 }
475531 CasState :: Unversioned ( old_val) => {
476- update_item = update_item
532+ update = update
477533 . condition_expression ( "#VAL = :old_val" )
478534 . expression_attribute_values ( ":old_val" , AttributeValue :: B ( old_val) ) ;
479535 }
480536 CasState :: Unset => {
481- update_item = update_item . condition_expression ( "attribute_not_exists (#VAL)" ) ;
537+ update = update . condition_expression ( "attribute_not_exists (#VAL)" ) ;
482538 }
483539 CasState :: Unknown => ( ) ,
484540 } ;
485541
486- update_item
542+ self . client
543+ . transact_write_items ( )
544+ . transact_items (
545+ TransactWriteItem :: builder ( )
546+ . update (
547+ update
548+ . build ( )
549+ . map_err ( |e| SwapError :: Other ( format ! ( "{e:?}" ) ) ) ?,
550+ )
551+ . build ( ) ,
552+ )
487553 . send ( )
488554 . await
489555 . map_err ( |e| SwapError :: CasFailed ( format ! ( "{e:?}" ) ) ) ?;
@@ -499,30 +565,3 @@ impl Cas for CompareAndSwap {
499565 self . key . clone ( )
500566 }
501567}
502-
503- impl AwsDynamoStore {
504- async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
505- let mut primary_keys = Vec :: new ( ) ;
506-
507- let mut scan_paginator = self
508- . client
509- . scan ( )
510- . table_name ( self . table . as_str ( ) )
511- . projection_expression ( PK )
512- . into_paginator ( )
513- . send ( ) ;
514-
515- while let Some ( output) = scan_paginator. next ( ) . await {
516- let scan_output = output. map_err ( log_error) ?;
517- if let Some ( items) = scan_output. items {
518- for mut item in items {
519- if let Some ( AttributeValue :: S ( pk) ) = item. remove ( PK ) {
520- primary_keys. push ( pk) ;
521- }
522- }
523- }
524- }
525-
526- Ok ( primary_keys)
527- }
528- }
0 commit comments