@@ -2,6 +2,7 @@ use anyhow::{Context, Result};
22use azure_data_cosmos:: {
33 prelude:: {
44 AuthorizationToken , CollectionClient , CosmosClient , CosmosClientBuilder , Operation , Query ,
5+ TentativeWritesAllowance ,
56 } ,
67 CosmosEntity ,
78} ;
@@ -145,7 +146,7 @@ struct AzureCosmosStore {
145146#[ async_trait]
146147impl Store for AzureCosmosStore {
147148 async fn get ( & self , key : & str ) -> Result < Option < Vec < u8 > > , Error > {
148- let pair = self . get_pair ( key) . await ?;
149+ let pair = self . get_pair :: < Pair > ( key) . await ?;
149150 Ok ( pair. map ( |p| p. value ) )
150151 }
151152
@@ -158,24 +159,35 @@ impl Store for AzureCosmosStore {
158159 self . client
159160 . create_document ( pair)
160161 . is_upsert ( true )
162+ . allow_tentative_writes ( TentativeWritesAllowance :: Allow )
161163 . await
162164 . map_err ( log_error) ?;
163165 Ok ( ( ) )
164166 }
165167
166168 async fn delete ( & self , key : & str ) -> Result < ( ) , Error > {
167- if self . exists ( key) . await ? {
168- let document_client = self
169- . client
170- . document_client ( key, & self . store_id )
171- . map_err ( log_error) ?;
172- document_client. delete_document ( ) . await . map_err ( log_error) ?;
169+ let document_client = self
170+ . client
171+ . document_client ( key, & self . store_id . clone ( ) . unwrap_or ( key. to_string ( ) ) )
172+ . map_err ( log_error) ?;
173+ if let Err ( e) = document_client
174+ . delete_document ( )
175+ . allow_tentative_writes ( TentativeWritesAllowance :: Allow )
176+ . await
177+ {
178+ if e. as_http_error ( )
179+ . map ( |e| e. status ( ) )
180+ . unwrap_or ( azure_core:: StatusCode :: Continue )
181+ != 404
182+ {
183+ return Err ( log_error ( e) ) ;
184+ }
173185 }
174186 Ok ( ( ) )
175187 }
176188
177189 async fn exists ( & self , key : & str ) -> Result < bool , Error > {
178- Ok ( self . get_pair ( key) . await ?. is_some ( ) )
190+ Ok ( self . get_pair :: < Key > ( key) . await ?. is_some ( ) )
179191 }
180192
181193 async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
@@ -216,24 +228,67 @@ impl Store for AzureCosmosStore {
216228 Ok ( ( ) )
217229 }
218230
231+ // WARNING: this function only works on the primary region because the
232+ // `azure_data_cosmos-0.21.0` release of the Azure cosmos SDK does not
233+ // support setting allow_tentative_writes on patch requests. The initial
234+ // value for the item must be set through this interfaces, as this sets the
235+ // number value if it does not exist. If the value was previously set using
236+ // the `set` interface, this will fail due to a type mismatch.
237+ //
238+ // TODO: The function should parse the new value from the return response
239+ // rather than sending an additional new request. However, the current SDK
240+ // version does not support this.
219241 async fn increment ( & self , key : String , delta : i64 ) -> Result < i64 , Error > {
220242 let operations = vec ! [ Operation :: incr( "/value" , delta) . map_err( log_error) ?] ;
221- let _ = self
243+ match self
222244 . client
223- . document_client ( key. clone ( ) , & self . store_id )
245+ . document_client ( & key, & self . store_id . clone ( ) . unwrap_or ( key . to_string ( ) ) )
224246 . map_err ( log_error) ?
225247 . patch_document ( operations)
226248 . await
227- . map_err ( log_error) ?;
228- let pair = self . get_pair ( key. as_ref ( ) ) . await ?;
229- match pair {
230- Some ( p) => Ok ( i64:: from_le_bytes (
231- p. value . try_into ( ) . expect ( "incorrect length" ) ,
232- ) ) ,
233- None => Err ( Error :: Other (
234- "increment returned an empty value after patching, which indicates a bug"
235- . to_string ( ) ,
236- ) ) ,
249+ {
250+ Err ( e) => {
251+ if e. as_http_error ( )
252+ . map ( |e| e. status ( ) )
253+ . unwrap_or ( azure_core:: StatusCode :: Continue )
254+ == 404
255+ {
256+ let counter = Counter {
257+ id : key. clone ( ) ,
258+ value : delta,
259+ store_id : self . store_id . clone ( ) ,
260+ } ;
261+ if let Err ( e) = self
262+ . client
263+ . create_document ( counter)
264+ . is_upsert ( false )
265+ . allow_tentative_writes ( TentativeWritesAllowance :: Allow )
266+ . await
267+ {
268+ if e. as_http_error ( )
269+ . map ( |e| e. status ( ) )
270+ . unwrap_or ( azure_core:: StatusCode :: Continue )
271+ == 409
272+ {
273+ // Conflict trying to create counter, retry increment
274+ self . increment ( key, delta) . await ?;
275+ } else {
276+ return Err ( log_error ( e) ) ;
277+ }
278+ }
279+ Ok ( delta)
280+ } else {
281+ Err ( log_error ( e) )
282+ }
283+ }
284+ Ok ( _) => self
285+ . get_pair :: < Counter > ( key. as_ref ( ) )
286+ . await ?
287+ . map ( |c| c. value )
288+ . ok_or ( Error :: Other (
289+ "increment returned an empty value after patching, which indicates a bug"
290+ . to_string ( ) ,
291+ ) ) ,
237292 }
238293 }
239294
@@ -274,8 +329,9 @@ impl CompareAndSwap {
274329
275330#[ async_trait]
276331impl Cas for CompareAndSwap {
277- /// `current` will fetch the current value for the key and store the etag for the record. The
278- /// etag will be used to perform and optimistic concurrency update using the `if-match` header.
332+ /// `current` will fetch the current value for the key and store the etag
333+ /// for the record. The etag will be used to perform and optimistic
334+ /// concurrency update using the `if-match` header.
279335 async fn current ( & self ) -> Result < Option < Vec < u8 > > , Error > {
280336 let mut stream = self
281337 . client
@@ -307,8 +363,8 @@ impl Cas for CompareAndSwap {
307363 }
308364 }
309365
310- /// `swap` updates the value for the key using the etag saved in the `current` function for
311- /// optimistic concurrency.
366+ /// `swap` updates the value for the key using the etag saved in the
367+ /// `current` function for optimistic concurrency.
312368 async fn swap ( & self , value : Vec < u8 > ) -> Result < ( ) , SwapError > {
313369 let pair = Pair {
314370 id : self . key . clone ( ) ,
@@ -327,15 +383,18 @@ impl Cas for CompareAndSwap {
327383 // attempt to replace the document if the etag matches
328384 doc_client
329385 . replace_document ( pair)
386+ . allow_tentative_writes ( TentativeWritesAllowance :: Allow )
330387 . if_match_condition ( azure_core:: request_options:: IfMatchCondition :: Match ( etag) )
331388 . await
332389 . map_err ( |e| SwapError :: CasFailed ( format ! ( "{e:?}" ) ) )
333390 . map ( drop)
334391 }
335392 None => {
336- // if we have no etag, then we assume the document does not yet exist and must insert; no upserts.
393+ // if we have no etag, then we assume the document does not yet
394+ // exist and must insert; no upserts.
337395 self . client
338396 . create_document ( pair)
397+ . allow_tentative_writes ( TentativeWritesAllowance :: Allow )
339398 . await
340399 . map_err ( |e| SwapError :: CasFailed ( format ! ( "{e:?}" ) ) )
341400 . map ( drop)
@@ -353,15 +412,18 @@ impl Cas for CompareAndSwap {
353412}
354413
355414impl AzureCosmosStore {
356- async fn get_pair ( & self , key : & str ) -> Result < Option < Pair > , Error > {
415+ async fn get_pair < F > ( & self , key : & str ) -> Result < Option < F > , Error >
416+ where
417+ F : CosmosEntity + Send + Sync + serde:: de:: DeserializeOwned + Clone ,
418+ {
357419 let query = self
358420 . client
359421 . query_documents ( Query :: new ( self . get_query ( key) ) )
360422 . query_cross_partition ( true )
361423 . max_item_count ( 1 ) ;
362424
363425 // There can be no duplicated keys, so we create the stream and only take the first result.
364- let mut stream = query. into_stream :: < Pair > ( ) ;
426+ let mut stream = query. into_stream :: < F > ( ) ;
365427 let Some ( res) = stream. next ( ) . await else {
366428 return Ok ( None ) ;
367429 } ;
@@ -379,10 +441,10 @@ impl AzureCosmosStore {
379441 . query_cross_partition ( true ) ;
380442 let mut res = Vec :: new ( ) ;
381443
382- let mut stream = query. into_stream :: < Pair > ( ) ;
444+ let mut stream = query. into_stream :: < Key > ( ) ;
383445 while let Some ( resp) = stream. next ( ) . await {
384446 let resp = resp. map_err ( log_error) ?;
385- res. extend ( resp. results . into_iter ( ) . map ( |( pair , _) | pair . id ) ) ;
447+ res. extend ( resp. results . into_iter ( ) . map ( |( key , _) | key . id ) ) ;
386448 }
387449
388450 Ok ( res)
@@ -435,6 +497,7 @@ fn append_store_id_condition(
435497 }
436498}
437499
500+ // Pair structure for key value operations
438501#[ derive( Serialize , Deserialize , Clone , Debug ) ]
439502pub struct Pair {
440503 pub id : String ,
@@ -450,3 +513,36 @@ impl CosmosEntity for Pair {
450513 self . store_id . clone ( ) . unwrap_or_else ( || self . id . clone ( ) )
451514 }
452515}
516+
517+ // Counter structure for increment operations
518+ #[ derive( Serialize , Deserialize , Clone , Debug ) ]
519+ pub struct Counter {
520+ pub id : String ,
521+ pub value : i64 ,
522+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
523+ pub store_id : Option < String > ,
524+ }
525+
526+ impl CosmosEntity for Counter {
527+ type Entity = String ;
528+
529+ fn partition_key ( & self ) -> Self :: Entity {
530+ self . store_id . clone ( ) . unwrap_or_else ( || self . id . clone ( ) )
531+ }
532+ }
533+
534+ // Key structure for operations with generic value types
535+ #[ derive( Serialize , Deserialize , Clone , Debug ) ]
536+ pub struct Key {
537+ pub id : String ,
538+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
539+ pub store_id : Option < String > ,
540+ }
541+
542+ impl CosmosEntity for Key {
543+ type Entity = String ;
544+
545+ fn partition_key ( & self ) -> Self :: Entity {
546+ self . store_id . clone ( ) . unwrap_or_else ( || self . id . clone ( ) )
547+ }
548+ }
0 commit comments