@@ -10,22 +10,23 @@ use aws_sdk_dynamodb::{
1010 config:: { ProvideCredentials , SharedCredentialsProvider } ,
1111 operation:: {
1212 batch_get_item:: BatchGetItemOutput , batch_write_item:: BatchWriteItemOutput ,
13- get_item:: GetItemOutput , update_item :: UpdateItemOutput ,
13+ get_item:: GetItemOutput ,
1414 } ,
1515 primitives:: Blob ,
16- types:: {
17- AttributeValue , DeleteRequest , KeysAndAttributes , PutRequest , TransactWriteItem , Update ,
18- WriteRequest ,
19- } ,
16+ types:: { AttributeValue , DeleteRequest , KeysAndAttributes , PutRequest , WriteRequest } ,
2017 Client ,
2118} ;
2219use spin_core:: async_trait;
2320use spin_factor_key_value:: { log_error, Cas , Error , Store , StoreManager , SwapError } ;
2421
2522pub struct KeyValueAwsDynamo {
23+ /// AWS region
2624 region : String ,
27- // Needs to be cloned when getting a store
25+ /// Whether to use strongly consistent reads
26+ consistent_read : bool ,
27+ /// DynamoDB table, needs to be cloned when getting a store
2828 table : Arc < String > ,
29+ /// DynamoDB client
2930 client : async_once_cell:: Lazy <
3031 Client ,
3132 std:: pin:: Pin < Box < dyn std:: future:: Future < Output = Client > + Send > > ,
@@ -84,6 +85,7 @@ pub enum KeyValueAwsDynamoAuthOptions {
8485impl KeyValueAwsDynamo {
8586 pub fn new (
8687 region : String ,
88+ consistent_read : bool ,
8789 table : String ,
8890 auth_options : KeyValueAwsDynamoAuthOptions ,
8991 ) -> Result < Self > {
@@ -104,6 +106,7 @@ impl KeyValueAwsDynamo {
104106
105107 Ok ( Self {
106108 region,
109+ consistent_read,
107110 table : Arc :: new ( table) ,
108111 client : async_once_cell:: Lazy :: from_future ( client_fut) ,
109112 } )
@@ -116,6 +119,7 @@ impl StoreManager for KeyValueAwsDynamo {
116119 Ok ( Arc :: new ( AwsDynamoStore {
117120 client : self . client . get_unpin ( ) . await . clone ( ) ,
118121 table : self . table . clone ( ) ,
122+ consistent_read : self . consistent_read ,
119123 } ) )
120124 }
121125
@@ -135,29 +139,43 @@ struct AwsDynamoStore {
135139 // Client wraps an Arc so should be low cost to clone
136140 client : Client ,
137141 table : Arc < String > ,
142+ consistent_read : bool ,
143+ }
144+
145+ #[ derive( Debug , Clone ) ]
146+ enum CasState {
147+ // Existing item with version
148+ Versioned ( String ) ,
149+ // Existing item without version
150+ Unversioned ( Blob ) ,
151+ // Item was null when fetched during `current`
152+ Unset ,
153+ // Potentially new item -- `current` was never called to fetch version
154+ Unknown ,
138155}
139156
140157struct CompareAndSwap {
141158 key : String ,
142159 client : Client ,
143160 table : Arc < String > ,
144161 bucket_rep : u32 ,
145- has_lock : Mutex < bool > ,
162+ state : Mutex < CasState > ,
146163}
147164
148165/// Primary key in DynamoDB items used for querying items
149166const PK : & str = "PK" ;
150167/// Value key in DynamoDB items storing item value as binary
151- const VAL : & str = "val " ;
152- /// Lock key in DynamoDB items used for atomic operations
153- const LOCK : & str = "lock " ;
168+ const VAL : & str = "VAL " ;
169+ /// Version key in DynamoDB items used for atomic operations
170+ const VER : & str = "VER " ;
154171
155172#[ async_trait]
156173impl Store for AwsDynamoStore {
157174 async fn get ( & self , key : & str ) -> Result < Option < Vec < u8 > > , Error > {
158175 let response = self
159176 . client
160177 . get_item ( )
178+ . consistent_read ( self . consistent_read )
161179 . table_name ( self . table . as_str ( ) )
162180 . key (
163181 PK ,
@@ -208,6 +226,7 @@ impl Store for AwsDynamoStore {
208226 let GetItemOutput { item, .. } = self
209227 . client
210228 . get_item ( )
229+ . consistent_read ( self . consistent_read )
211230 . table_name ( self . table . as_str ( ) )
212231 . key (
213232 PK ,
@@ -228,8 +247,13 @@ impl Store for AwsDynamoStore {
228247 async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
229248 let mut results = Vec :: with_capacity ( keys. len ( ) ) ;
230249
231- let mut keys_and_attributes_builder =
232- KeysAndAttributes :: builder ( ) . projection_expression ( format ! ( "{PK},{VAL}" ) ) ;
250+ if keys. is_empty ( ) {
251+ return Ok ( results) ;
252+ }
253+
254+ let mut keys_and_attributes_builder = KeysAndAttributes :: builder ( )
255+ . projection_expression ( format ! ( "{PK},{VAL}" ) )
256+ . consistent_read ( self . consistent_read ) ;
233257 for key in keys {
234258 keys_and_attributes_builder = keys_and_attributes_builder. keys ( HashMap :: from_iter ( [ (
235259 PK . to_owned ( ) ,
@@ -243,7 +267,7 @@ impl Store for AwsDynamoStore {
243267
244268 while request_items. is_some ( ) {
245269 let BatchGetItemOutput {
246- responses : Some ( mut responses ) ,
270+ responses,
247271 unprocessed_keys,
248272 ..
249273 } = self
@@ -252,25 +276,21 @@ impl Store for AwsDynamoStore {
252276 . set_request_items ( request_items)
253277 . send ( )
254278 . await
255- . map_err ( log_error) ?
256- else {
257- return Err ( Error :: Other ( "No results" . into ( ) ) ) ;
258- } ;
279+ . map_err ( log_error) ?;
259280
260- if let Some ( items) = responses. remove ( self . table . as_str ( ) ) {
281+ if let Some ( items) =
282+ responses. and_then ( |mut responses| responses. remove ( self . table . as_str ( ) ) )
283+ {
261284 for mut item in items {
262- let Some ( AttributeValue :: S ( pk) ) = item. remove ( PK ) else {
263- return Err ( Error :: Other (
264- "Could not find 'PK' key on DynamoDB item" . into ( ) ,
265- ) ) ;
266- } ;
267- let Some ( AttributeValue :: B ( val) ) = item. remove ( VAL ) else {
268- return Err ( Error :: Other (
269- "Could not find 'val' key on DynamoDB item" . into ( ) ,
270- ) ) ;
271- } ;
272-
273- results. push ( ( pk, Some ( val. into_inner ( ) ) ) ) ;
285+ match ( item. remove ( PK ) , item. remove ( VAL ) ) {
286+ ( Some ( AttributeValue :: S ( pk) ) , Some ( AttributeValue :: B ( val) ) ) => {
287+ results. push ( ( pk, Some ( val. into_inner ( ) ) ) ) ;
288+ }
289+ ( Some ( AttributeValue :: S ( pk) ) , None ) => {
290+ results. push ( ( pk, None ) ) ;
291+ }
292+ _ => ( ) ,
293+ }
274294 }
275295 }
276296
@@ -355,8 +375,8 @@ impl Store for AwsDynamoStore {
355375 . update_item ( )
356376 . table_name ( self . table . as_str ( ) )
357377 . key ( PK , AttributeValue :: S ( key) )
358- . update_expression ( "ADD #val :delta" )
359- . expression_attribute_names ( "#val " , VAL )
378+ . update_expression ( "ADD #VAL :delta" )
379+ . expression_attribute_names ( "#VAL " , VAL )
360380 . expression_attribute_values ( ":delta" , AttributeValue :: N ( delta. to_string ( ) ) )
361381 . return_values ( aws_sdk_dynamodb:: types:: ReturnValue :: UpdatedNew )
362382 . send ( )
@@ -381,7 +401,7 @@ impl Store for AwsDynamoStore {
381401 key : key. to_string ( ) ,
382402 client : self . client . clone ( ) ,
383403 table : self . table . clone ( ) ,
384- has_lock : Mutex :: new ( false ) ,
404+ state : Mutex :: new ( CasState :: Unknown ) ,
385405 bucket_rep,
386406 } ) )
387407 }
@@ -390,60 +410,80 @@ impl Store for AwsDynamoStore {
390410#[ async_trait]
391411impl Cas for CompareAndSwap {
392412 async fn current ( & self ) -> Result < Option < Vec < u8 > > , Error > {
393- let UpdateItemOutput { attributes , .. } = self
413+ let GetItemOutput { item , .. } = self
394414 . client
395- . update_item ( )
415+ . get_item ( )
416+ . consistent_read ( true )
396417 . table_name ( self . table . as_str ( ) )
397418 . key ( PK , AttributeValue :: S ( self . key . clone ( ) ) )
398- . update_expression ( "SET #lock=:lock" )
399- . expression_attribute_names ( "#lock" , LOCK )
400- . expression_attribute_values ( ":lock" , AttributeValue :: Null ( true ) )
401- . condition_expression ( "attribute_not_exists (#lock)" )
402- . return_values ( aws_sdk_dynamodb:: types:: ReturnValue :: AllNew )
419+ . projection_expression ( format ! ( "{VAL},{VER}" ) )
403420 . send ( )
404421 . await
405422 . map_err ( log_error) ?;
406423
407- self . has_lock . lock ( ) . unwrap ( ) . clone_from ( & true ) ;
424+ match item {
425+ Some ( mut current_item) => match ( current_item. remove ( VAL ) , current_item. remove ( VER ) ) {
426+ ( Some ( AttributeValue :: B ( val) ) , Some ( AttributeValue :: N ( ver) ) ) => {
427+ self . state
428+ . lock ( )
429+ . unwrap ( )
430+ . clone_from ( & CasState :: Versioned ( ver) ) ;
431+
432+ Ok ( Some ( val. into_inner ( ) ) )
433+ }
434+ ( Some ( AttributeValue :: B ( val) ) , _) => {
435+ self . state
436+ . lock ( )
437+ . unwrap ( )
438+ . clone_from ( & CasState :: Unversioned ( val. clone ( ) ) ) ;
408439
409- match attributes {
410- Some ( mut item) => match item. remove ( VAL ) {
411- Some ( AttributeValue :: B ( val) ) => Ok ( Some ( val. into_inner ( ) ) ) ,
412- _ => Ok ( None ) ,
440+ Ok ( Some ( val. into_inner ( ) ) )
441+ }
442+ ( _, _) => {
443+ self . state . lock ( ) . unwrap ( ) . clone_from ( & CasState :: Unset ) ;
444+ Ok ( None )
445+ }
413446 } ,
414- None => Ok ( None ) ,
447+ None => {
448+ self . state . lock ( ) . unwrap ( ) . clone_from ( & CasState :: Unset ) ;
449+ Ok ( None )
450+ }
415451 }
416452 }
417453
418- /// `swap` updates the value for the key using the version saved in the `current` function for
419- /// optimistic concurrency.
454+ /// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
455+ /// optimistic concurrency or the previous item value
420456 async fn swap ( & self , value : Vec < u8 > ) -> Result < ( ) , SwapError > {
421- let mut update_item = Update :: builder ( )
457+ let mut update_item = self
458+ . client
459+ . update_item ( )
422460 . table_name ( self . table . as_str ( ) )
423461 . key ( PK , AttributeValue :: S ( self . key . clone ( ) ) )
424- . update_expression ( "SET #val=:val REMOVE #lock" )
425- . expression_attribute_names ( "#val" , VAL )
462+ . update_expression ( "SET #VAL = :val ADD #VER :increment" )
463+ . expression_attribute_names ( "#VAL" , VAL )
464+ . expression_attribute_names ( "#VER" , VER )
426465 . expression_attribute_values ( ":val" , AttributeValue :: B ( Blob :: new ( value) ) )
427- . expression_attribute_names ( "#lock" , LOCK ) ;
428-
429- let has_lock = * self . has_lock . lock ( ) . unwrap ( ) ;
430- // Ensure exclusive access between fetching the current value of the item and swapping
431- if has_lock {
432- update_item = update_item. condition_expression ( "attribute_exists (#lock)" ) ;
433- }
466+ . expression_attribute_values ( ":increment" , AttributeValue :: N ( "1" . to_owned ( ) ) ) ;
467+
468+ let state = self . state . lock ( ) . unwrap ( ) . clone ( ) ;
469+ match state {
470+ CasState :: Versioned ( version) => {
471+ update_item = update_item
472+ . condition_expression ( "#VER = :ver" )
473+ . expression_attribute_values ( ":ver" , AttributeValue :: N ( version) ) ;
474+ }
475+ CasState :: Unversioned ( old_val) => {
476+ update_item = update_item
477+ . condition_expression ( "#VAL = :old_val" )
478+ . expression_attribute_values ( ":old_val" , AttributeValue :: B ( old_val) ) ;
479+ }
480+ CasState :: Unset => {
481+ update_item = update_item. condition_expression ( "attribute_not_exists (#VAL)" ) ;
482+ }
483+ CasState :: Unknown => ( ) ,
484+ } ;
434485
435- // TransactWriteItems fails if concurrent writes are in progress on an item, so even without locking, we get atomicity in overwriting
436- self . client
437- . transact_write_items ( )
438- . transact_items (
439- TransactWriteItem :: builder ( )
440- . update (
441- update_item
442- . build ( )
443- . map_err ( |e| SwapError :: Other ( format ! ( "{e:?}" ) ) ) ?,
444- )
445- . build ( ) ,
446- )
486+ update_item
447487 . send ( )
448488 . await
449489 . map_err ( |e| SwapError :: CasFailed ( format ! ( "{e:?}" ) ) ) ?;
@@ -463,35 +503,24 @@ impl Cas for CompareAndSwap {
463503impl AwsDynamoStore {
464504 async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
465505 let mut primary_keys = Vec :: new ( ) ;
466- let mut last_evaluated_key = None ;
467506
468- loop {
469- let mut scan_builder = self
470- . client
471- . scan ( )
472- . table_name ( self . table . as_str ( ) )
473- . projection_expression ( PK ) ;
474-
475- if let Some ( keys) = last_evaluated_key {
476- for ( key, val) in keys {
477- scan_builder = scan_builder. exclusive_start_key ( key, val) ;
478- }
479- }
480-
481- let scan_output = scan_builder. send ( ) . await . map_err ( log_error) ?;
507+ let mut scan_paginator = self
508+ . client
509+ . scan ( )
510+ . table_name ( self . table . as_str ( ) )
511+ . projection_expression ( PK )
512+ . into_paginator ( )
513+ . send ( ) ;
482514
515+ while let Some ( output) = scan_paginator. next ( ) . await {
516+ let scan_output = output. map_err ( log_error) ?;
483517 if let Some ( items) = scan_output. items {
484518 for mut item in items {
485519 if let Some ( AttributeValue :: S ( pk) ) = item. remove ( PK ) {
486520 primary_keys. push ( pk) ;
487521 }
488522 }
489523 }
490-
491- last_evaluated_key = scan_output. last_evaluated_key ;
492- if last_evaluated_key. is_none ( ) {
493- break ;
494- }
495524 }
496525
497526 Ok ( primary_keys)
0 commit comments