@@ -138,14 +138,18 @@ struct AzureCosmosStore {
138138 client : CollectionClient ,
139139 /// An optional store id to use as a partition key for all operations.
140140 ///
141- /// If the store id not set, the store will use `/id` as the partition key.
141+ /// If the store ID is not set, the store will use `/id` (the row key) as
142+ /// the partition key. For example, if `store.set("my_key", "my_value")` is
143+ /// called, the partition key will be `my_key` if the store ID is set to
144+ /// `None`. If the store ID is set to `Some("myappid/default"), the
145+ /// partition key will be `myappid/default`.
142146 store_id : Option < String > ,
143147}
144148
145149#[ async_trait]
146150impl Store for AzureCosmosStore {
147151 async fn get ( & self , key : & str ) -> Result < Option < Vec < u8 > > , Error > {
148- let pair = self . get_pair ( key) . await ?;
152+ let pair = self . get_entity :: < Pair > ( key) . await ?;
149153 Ok ( pair. map ( |p| p. value ) )
150154 }
151155
@@ -164,18 +168,20 @@ impl Store for AzureCosmosStore {
164168 }
165169
166170 async fn delete ( & self , key : & str ) -> Result < ( ) , Error > {
167- if self . exists ( key) . await ? {
168- let document_client = self
169- . client
170- . document_client ( key, & self . store_id )
171- . map_err ( log_error) ?;
172- document_client. delete_document ( ) . await . map_err ( log_error) ?;
171+ let document_client = self
172+ . client
173+ . document_client ( key, & self . store_id . clone ( ) . unwrap_or ( key. to_string ( ) ) )
174+ . map_err ( log_error) ?;
175+ if let Err ( e) = document_client. delete_document ( ) . await {
176+ if e. as_http_error ( ) . map ( |e| e. status ( ) != 404 ) . unwrap_or ( true ) {
177+ return Err ( log_error ( e) ) ;
178+ }
173179 }
174180 Ok ( ( ) )
175181 }
176182
177183 async fn exists ( & self , key : & str ) -> Result < bool , Error > {
178- Ok ( self . get_pair ( key) . await ?. is_some ( ) )
184+ Ok ( self . get_entity :: < Key > ( key) . await ?. is_some ( ) )
179185 }
180186
181187 async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
@@ -216,24 +222,58 @@ impl Store for AzureCosmosStore {
216222 Ok ( ( ) )
217223 }
218224
225+ /// Increments a numerical value.
226+ ///
227+ /// The initial value for the item must be set through this interface, as this sets the
228+ /// number value if it does not exist. If the value was previously set using
229+ /// the `set` interface, this will fail due to a type mismatch.
230+ // TODO: The function should parse the new value from the return response
231+ // rather than sending an additional new request. However, the current SDK
232+ // version does not support this.
219233 async fn increment ( & self , key : String , delta : i64 ) -> Result < i64 , Error > {
220234 let operations = vec ! [ Operation :: incr( "/value" , delta) . map_err( log_error) ?] ;
221- let _ = self
235+ match self
222236 . client
223- . document_client ( key. clone ( ) , & self . store_id )
237+ . document_client ( & key, & self . store_id . clone ( ) . unwrap_or ( key . to_string ( ) ) )
224238 . map_err ( log_error) ?
225239 . patch_document ( operations)
226240 . await
227- . map_err ( log_error) ?;
228- let pair = self . get_pair ( key. as_ref ( ) ) . await ?;
229- match pair {
230- Some ( p) => Ok ( i64:: from_le_bytes (
231- p. value . try_into ( ) . expect ( "incorrect length" ) ,
232- ) ) ,
233- None => Err ( Error :: Other (
234- "increment returned an empty value after patching, which indicates a bug"
235- . to_string ( ) ,
236- ) ) ,
241+ {
242+ Err ( e) => {
243+ if e. as_http_error ( )
244+ . map ( |e| e. status ( ) == 404 )
245+ . unwrap_or ( false )
246+ {
247+ let counter = Counter {
248+ id : key. clone ( ) ,
249+ value : delta,
250+ store_id : self . store_id . clone ( ) ,
251+ } ;
252+ if let Err ( e) = self . client . create_document ( counter) . is_upsert ( false ) . await {
253+ if e. as_http_error ( )
254+ . map ( |e| e. status ( ) )
255+ . unwrap_or ( azure_core:: StatusCode :: Continue )
256+ == 409
257+ {
258+ // Conflict trying to create counter, retry increment
259+ self . increment ( key, delta) . await ?;
260+ } else {
261+ return Err ( log_error ( e) ) ;
262+ }
263+ }
264+ Ok ( delta)
265+ } else {
266+ Err ( log_error ( e) )
267+ }
268+ }
269+ Ok ( _) => self
270+ . get_entity :: < Counter > ( key. as_ref ( ) )
271+ . await ?
272+ . map ( |c| c. value )
273+ . ok_or ( Error :: Other (
274+ "increment returned an empty value after patching, which indicates a bug"
275+ . to_string ( ) ,
276+ ) ) ,
237277 }
238278 }
239279
@@ -353,15 +393,18 @@ impl Cas for CompareAndSwap {
353393}
354394
355395impl AzureCosmosStore {
356- async fn get_pair ( & self , key : & str ) -> Result < Option < Pair > , Error > {
396+ async fn get_entity < F > ( & self , key : & str ) -> Result < Option < F > , Error >
397+ where
398+ F : CosmosEntity + Send + Sync + serde:: de:: DeserializeOwned + Clone ,
399+ {
357400 let query = self
358401 . client
359402 . query_documents ( Query :: new ( self . get_query ( key) ) )
360403 . query_cross_partition ( true )
361404 . max_item_count ( 1 ) ;
362405
363406 // There can be no duplicated keys, so we create the stream and only take the first result.
364- let mut stream = query. into_stream :: < Pair > ( ) ;
407+ let mut stream = query. into_stream :: < F > ( ) ;
365408 let Some ( res) = stream. next ( ) . await else {
366409 return Ok ( None ) ;
367410 } ;
@@ -379,10 +422,10 @@ impl AzureCosmosStore {
379422 . query_cross_partition ( true ) ;
380423 let mut res = Vec :: new ( ) ;
381424
382- let mut stream = query. into_stream :: < Pair > ( ) ;
425+ let mut stream = query. into_stream :: < Key > ( ) ;
383426 while let Some ( resp) = stream. next ( ) . await {
384427 let resp = resp. map_err ( log_error) ?;
385- res. extend ( resp. results . into_iter ( ) . map ( |( pair , _) | pair . id ) ) ;
428+ res. extend ( resp. results . into_iter ( ) . map ( |( key , _) | key . id ) ) ;
386429 }
387430
388431 Ok ( res)
@@ -435,6 +478,7 @@ fn append_store_id_condition(
435478 }
436479}
437480
481+ // Pair structure for key value operations
438482#[ derive( Serialize , Deserialize , Clone , Debug ) ]
439483pub struct Pair {
440484 pub id : String ,
@@ -450,3 +494,36 @@ impl CosmosEntity for Pair {
450494 self . store_id . clone ( ) . unwrap_or_else ( || self . id . clone ( ) )
451495 }
452496}
497+
498+ // Counter structure for increment operations
499+ #[ derive( Serialize , Deserialize , Clone , Debug ) ]
500+ pub struct Counter {
501+ pub id : String ,
502+ pub value : i64 ,
503+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
504+ pub store_id : Option < String > ,
505+ }
506+
507+ impl CosmosEntity for Counter {
508+ type Entity = String ;
509+
510+ fn partition_key ( & self ) -> Self :: Entity {
511+ self . store_id . clone ( ) . unwrap_or_else ( || self . id . clone ( ) )
512+ }
513+ }
514+
515+ // Key structure for operations with generic value types
516+ #[ derive( Serialize , Deserialize , Clone , Debug ) ]
517+ pub struct Key {
518+ pub id : String ,
519+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
520+ pub store_id : Option < String > ,
521+ }
522+
523+ impl CosmosEntity for Key {
524+ type Entity = String ;
525+
526+ fn partition_key ( & self ) -> Self :: Entity {
527+ self . store_id . clone ( ) . unwrap_or_else ( || self . id . clone ( ) )
528+ }
529+ }
0 commit comments