1
+ use core:: str;
1
2
use std:: {
2
3
collections:: HashMap ,
3
4
sync:: { Arc , Mutex } ,
@@ -13,7 +14,10 @@ use aws_sdk_dynamodb::{
13
14
get_item:: GetItemOutput ,
14
15
} ,
15
16
primitives:: Blob ,
16
- types:: { AttributeValue , DeleteRequest , KeysAndAttributes , PutRequest , WriteRequest } ,
17
+ types:: {
18
+ AttributeValue , DeleteRequest , KeysAndAttributes , PutRequest , TransactWriteItem , Update ,
19
+ WriteRequest ,
20
+ } ,
17
21
Client ,
18
22
} ;
19
23
use spin_core:: async_trait;
@@ -148,7 +152,7 @@ enum CasState {
148
152
Versioned ( String ) ,
149
153
// Existing item without version
150
154
Unversioned ( Blob ) ,
151
- // Item was null when fetched during `current`
155
+ // Item was missing when fetched during `current`, expected to be new
152
156
Unset ,
153
157
// Potentially new item -- `current` was never called to fetch version
154
158
Unknown ,
@@ -210,15 +214,13 @@ impl Store for AwsDynamoStore {
210
214
}
211
215
212
216
async fn delete ( & self , key : & str ) -> Result < ( ) , Error > {
213
- if self . exists ( key) . await ? {
214
- self . client
215
- . delete_item ( )
216
- . table_name ( self . table . as_str ( ) )
217
- . key ( PK , AttributeValue :: S ( key. to_string ( ) ) )
218
- . send ( )
219
- . await
220
- . map_err ( log_error) ?;
221
- }
217
+ self . client
218
+ . delete_item ( )
219
+ . table_name ( self . table . as_str ( ) )
220
+ . key ( PK , AttributeValue :: S ( key. to_string ( ) ) )
221
+ . send ( )
222
+ . await
223
+ . map_err ( log_error) ?;
222
224
Ok ( ( ) )
223
225
}
224
226
@@ -241,16 +243,32 @@ impl Store for AwsDynamoStore {
241
243
}
242
244
243
245
async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
244
- self . get_keys ( ) . await
245
- }
246
+ let mut primary_keys = Vec :: new ( ) ;
246
247
247
- async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
248
- let mut results = Vec :: with_capacity ( keys. len ( ) ) ;
248
+ let mut scan_paginator = self
249
+ . client
250
+ . scan ( )
251
+ . table_name ( self . table . as_str ( ) )
252
+ . projection_expression ( PK )
253
+ . into_paginator ( )
254
+ . send ( ) ;
249
255
250
- if keys. is_empty ( ) {
251
- return Ok ( results) ;
256
+ while let Some ( output) = scan_paginator. next ( ) . await {
257
+ let scan_output = output. map_err ( log_error) ?;
258
+ if let Some ( items) = scan_output. items {
259
+ for mut item in items {
260
+ if let Some ( AttributeValue :: S ( pk) ) = item. remove ( PK ) {
261
+ primary_keys. push ( pk) ;
262
+ }
263
+ }
264
+ }
252
265
}
253
266
267
+ Ok ( primary_keys)
268
+ }
269
+
270
+ async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
271
+ let mut results = Vec :: with_capacity ( keys. len ( ) ) ;
254
272
let mut keys_and_attributes_builder = KeysAndAttributes :: builder ( )
255
273
. projection_expression ( format ! ( "{PK},{VAL}" ) )
256
274
. consistent_read ( self . consistent_read ) ;
@@ -370,26 +388,66 @@ impl Store for AwsDynamoStore {
370
388
}
371
389
372
390
async fn increment ( & self , key : String , delta : i64 ) -> Result < i64 , Error > {
373
- let result = self
391
+ let GetItemOutput { item , .. } = self
374
392
. client
375
- . update_item ( )
393
+ . get_item ( )
394
+ . consistent_read ( true )
376
395
. table_name ( self . table . as_str ( ) )
377
- . key ( PK , AttributeValue :: S ( key) )
378
- . update_expression ( "ADD #VAL :delta" )
379
- . expression_attribute_names ( "#VAL" , VAL )
380
- . expression_attribute_values ( ":delta" , AttributeValue :: N ( delta. to_string ( ) ) )
381
- . return_values ( aws_sdk_dynamodb:: types:: ReturnValue :: UpdatedNew )
396
+ . key ( PK , AttributeValue :: S ( key. clone ( ) ) )
397
+ . projection_expression ( VAL )
382
398
. send ( )
383
399
. await
384
400
. map_err ( log_error) ?;
385
401
386
- if let Some ( updated_attributes) = result. attributes {
387
- if let Some ( AttributeValue :: N ( new_value) ) = updated_attributes. get ( VAL ) {
388
- return Ok ( new_value. parse :: < i64 > ( ) . map_err ( log_error) ) ?;
389
- }
402
+ let old_val = match item {
403
+ Some ( mut current_item) => match current_item. remove ( VAL ) {
404
+ // We're expecting i64, so technically we could transmute but seems risky...
405
+ Some ( AttributeValue :: B ( val) ) => Some (
406
+ str:: from_utf8 ( & val. into_inner ( ) )
407
+ . map_err ( log_error) ?
408
+ . parse :: < i64 > ( )
409
+ . map_err ( log_error) ?,
410
+ ) ,
411
+ _ => None ,
412
+ } ,
413
+ None => None ,
414
+ } ;
415
+
416
+ let new_val = old_val. unwrap_or ( 0 ) + delta;
417
+
418
+ let mut update = Update :: builder ( )
419
+ . table_name ( self . table . as_str ( ) )
420
+ . key ( PK , AttributeValue :: S ( key) )
421
+ . update_expression ( "SET #VAL = :new_val" )
422
+ . expression_attribute_names ( "#VAL" , VAL )
423
+ . expression_attribute_values (
424
+ ":new_val" ,
425
+ AttributeValue :: B ( Blob :: new ( new_val. to_string ( ) . as_bytes ( ) ) ) ,
426
+ ) ;
427
+
428
+ if let Some ( old_val) = old_val {
429
+ update = update
430
+ . condition_expression ( "#VAL = :old_val" )
431
+ . expression_attribute_values (
432
+ ":old_val" ,
433
+ AttributeValue :: B ( Blob :: new ( old_val. to_string ( ) . as_bytes ( ) ) ) ,
434
+ )
435
+ } else {
436
+ update = update. condition_expression ( "attribute_not_exists (#VAL)" )
390
437
}
391
438
392
- Err ( Error :: Other ( "Failed to increment value" . into ( ) ) )
439
+ self . client
440
+ . transact_write_items ( )
441
+ . transact_items (
442
+ TransactWriteItem :: builder ( )
443
+ . update ( update. build ( ) . map_err ( log_error) ?)
444
+ . build ( ) ,
445
+ )
446
+ . send ( )
447
+ . await
448
+ . map_err ( log_error) ?;
449
+
450
+ Ok ( new_val)
393
451
}
394
452
395
453
async fn new_compare_and_swap (
@@ -454,9 +512,7 @@ impl Cas for CompareAndSwap {
454
512
/// `swap` updates the value for the key -- if possible, using the version saved in the `current` function for
455
513
/// optimistic concurrency or the previous item value
456
514
async fn swap ( & self , value : Vec < u8 > ) -> Result < ( ) , SwapError > {
457
- let mut update_item = self
458
- . client
459
- . update_item ( )
515
+ let mut update = Update :: builder ( )
460
516
. table_name ( self . table . as_str ( ) )
461
517
. key ( PK , AttributeValue :: S ( self . key . clone ( ) ) )
462
518
. update_expression ( "SET #VAL = :val ADD #VER :increment" )
@@ -468,22 +524,32 @@ impl Cas for CompareAndSwap {
468
524
let state = self . state . lock ( ) . unwrap ( ) . clone ( ) ;
469
525
match state {
470
526
CasState :: Versioned ( version) => {
471
- update_item = update_item
527
+ update = update
472
528
. condition_expression ( "#VER = :ver" )
473
529
. expression_attribute_values ( ":ver" , AttributeValue :: N ( version) ) ;
474
530
}
475
531
CasState :: Unversioned ( old_val) => {
476
- update_item = update_item
532
+ update = update
477
533
. condition_expression ( "#VAL = :old_val" )
478
534
. expression_attribute_values ( ":old_val" , AttributeValue :: B ( old_val) ) ;
479
535
}
480
536
CasState :: Unset => {
481
- update_item = update_item . condition_expression ( "attribute_not_exists (#VAL)" ) ;
537
+ update = update . condition_expression ( "attribute_not_exists (#VAL)" ) ;
482
538
}
483
539
CasState :: Unknown => ( ) ,
484
540
} ;
485
541
486
- update_item
542
+ self . client
543
+ . transact_write_items ( )
544
+ . transact_items (
545
+ TransactWriteItem :: builder ( )
546
+ . update (
547
+ update
548
+ . build ( )
549
+ . map_err ( |e| SwapError :: Other ( format ! ( "{e:?}" ) ) ) ?,
550
+ )
551
+ . build ( ) ,
552
+ )
487
553
. send ( )
488
554
. await
489
555
. map_err ( |e| SwapError :: CasFailed ( format ! ( "{e:?}" ) ) ) ?;
@@ -499,30 +565,3 @@ impl Cas for CompareAndSwap {
499
565
self . key . clone ( )
500
566
}
501
567
}
502
-
503
- impl AwsDynamoStore {
504
- async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
505
- let mut primary_keys = Vec :: new ( ) ;
506
-
507
- let mut scan_paginator = self
508
- . client
509
- . scan ( )
510
- . table_name ( self . table . as_str ( ) )
511
- . projection_expression ( PK )
512
- . into_paginator ( )
513
- . send ( ) ;
514
-
515
- while let Some ( output) = scan_paginator. next ( ) . await {
516
- let scan_output = output. map_err ( log_error) ?;
517
- if let Some ( items) = scan_output. items {
518
- for mut item in items {
519
- if let Some ( AttributeValue :: S ( pk) ) = item. remove ( PK ) {
520
- primary_keys. push ( pk) ;
521
- }
522
- }
523
- }
524
- }
525
-
526
- Ok ( primary_keys)
527
- }
528
- }
0 commit comments