@@ -13,7 +13,7 @@ use std::sync::{Arc, Mutex};
13
13
14
14
pub struct KeyValueAzureCosmos {
15
15
client : CollectionClient ,
16
- app_id : String ,
16
+ app_id : Option < String > ,
17
17
}
18
18
19
19
/// Azure Cosmos Key / Value runtime config literal options for authentication
@@ -72,7 +72,7 @@ impl KeyValueAzureCosmos {
72
72
database : String ,
73
73
container : String ,
74
74
auth_options : KeyValueAzureCosmosAuthOptions ,
75
- app_id : String ,
75
+ app_id : Option < String > ,
76
76
) -> Result < Self > {
77
77
let token = match auth_options {
78
78
KeyValueAzureCosmosAuthOptions :: RuntimeConfigValues ( config) => {
@@ -97,7 +97,7 @@ impl StoreManager for KeyValueAzureCosmos {
97
97
async fn get ( & self , name : & str ) -> Result < Arc < dyn Store > , Error > {
98
98
Ok ( Arc :: new ( AzureCosmosStore {
99
99
client : self . client . clone ( ) ,
100
- partition_key : format ! ( "{}/{}" , self . app_id , name ) ,
100
+ partition_key : self . app_id . as_ref ( ) . map ( |i| format ! ( "{i }/{name}" ) ) ,
101
101
} ) )
102
102
}
103
103
@@ -117,7 +117,10 @@ impl StoreManager for KeyValueAzureCosmos {
117
117
#[ derive( Clone ) ]
118
118
struct AzureCosmosStore {
119
119
client : CollectionClient ,
120
- partition_key : String ,
120
+ /// An optional partition key to use for all operations.
121
+ ///
122
+ /// If the partition key is not set, the store will use `/id` as the partition key.
123
+ partition_key : Option < String > ,
121
124
}
122
125
123
126
#[ async_trait]
@@ -161,15 +164,7 @@ impl Store for AzureCosmosStore {
161
164
}
162
165
163
166
async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
164
- let in_clause: String = keys
165
- . into_iter ( )
166
- . map ( |k| format ! ( "'{}'" , k) )
167
- . collect :: < Vec < String > > ( )
168
- . join ( ", " ) ;
169
- let stmt = Query :: new ( format ! (
170
- "SELECT * FROM c WHERE c.id IN ({}) AND partition_key='{}'" ,
171
- in_clause, self . partition_key
172
- ) ) ;
167
+ let stmt = Query :: new ( self . get_in_query ( keys) ) ;
173
168
let query = self
174
169
. client
175
170
. query_documents ( stmt)
@@ -243,7 +238,19 @@ struct CompareAndSwap {
243
238
client : CollectionClient ,
244
239
bucket_rep : u32 ,
245
240
etag : Mutex < Option < String > > ,
246
- partition_key : String ,
241
+ partition_key : Option < String > ,
242
+ }
243
+
244
+ impl CompareAndSwap {
245
+ fn get_query ( & self ) -> String {
246
+ let mut query = format ! ( "SELECT * FROM c WHERE c.id='{}'" , self . key) ;
247
+ self . append_partition_key ( & mut query) ;
248
+ query
249
+ }
250
+
251
+ fn append_partition_key ( & self , query : & mut String ) {
252
+ append_partition_key_condition ( query, self . partition_key . as_deref ( ) ) ;
253
+ }
247
254
}
248
255
249
256
#[ async_trait]
@@ -253,10 +260,7 @@ impl Cas for CompareAndSwap {
253
260
async fn current ( & self ) -> Result < Option < Vec < u8 > > , Error > {
254
261
let mut stream = self
255
262
. client
256
- . query_documents ( Query :: new ( format ! (
257
- "SELECT * FROM c WHERE c.id='{}' and c.partition_key='{}'" ,
258
- self . key, self . partition_key
259
- ) ) )
263
+ . query_documents ( Query :: new ( self . get_query ( ) ) )
260
264
. query_cross_partition ( true )
261
265
. max_item_count ( 1 )
262
266
. into_stream :: < Pair > ( ) ;
@@ -287,7 +291,11 @@ impl Cas for CompareAndSwap {
287
291
/// `swap` updates the value for the key using the etag saved in the `current` function for
288
292
/// optimistic concurrency.
289
293
async fn swap ( & self , value : Vec < u8 > ) -> Result < ( ) , SwapError > {
290
- let pk = PartitionKey :: from ( & self . partition_key ) ;
294
+ let pk = PartitionKey :: from (
295
+ self . partition_key
296
+ . as_deref ( )
297
+ . unwrap_or_else ( || self . key . as_str ( ) ) ,
298
+ ) ;
291
299
let pair = Pair {
292
300
id : self . key . clone ( ) ,
293
301
value,
@@ -334,10 +342,7 @@ impl AzureCosmosStore {
334
342
async fn get_pair ( & self , key : & str ) -> Result < Option < Pair > , Error > {
335
343
let query = self
336
344
. client
337
- . query_documents ( Query :: new ( format ! (
338
- "SELECT * FROM c WHERE c.id='{}' AND c.partition_key='{}'" ,
339
- key, self . partition_key
340
- ) ) )
345
+ . query_documents ( Query :: new ( self . get_query ( key) ) )
341
346
. query_cross_partition ( true )
342
347
. max_item_count ( 1 ) ;
343
348
@@ -356,7 +361,7 @@ impl AzureCosmosStore {
356
361
async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
357
362
let query = self
358
363
. client
359
- . query_documents ( Query :: new ( "SELECT * FROM c" . to_string ( ) ) )
364
+ . query_documents ( Query :: new ( self . get_keys_query ( ) ) )
360
365
. query_cross_partition ( true ) ;
361
366
let mut res = Vec :: new ( ) ;
362
367
@@ -368,19 +373,59 @@ impl AzureCosmosStore {
368
373
369
374
Ok ( res)
370
375
}
376
+
377
+ fn get_query ( & self , key : & str ) -> String {
378
+ let mut query = format ! ( "SELECT * FROM c WHERE c.id='{}'" , key) ;
379
+ self . append_partition_key ( & mut query) ;
380
+ query
381
+ }
382
+
383
+ fn get_keys_query ( & self ) -> String {
384
+ let mut query = "SELECT * FROM c" . to_owned ( ) ;
385
+ self . append_partition_key ( & mut query) ;
386
+ query
387
+ }
388
+
389
+ fn get_in_query ( & self , keys : Vec < String > ) -> String {
390
+ let in_clause: String = keys
391
+ . into_iter ( )
392
+ . map ( |k| format ! ( "'{}'" , k) )
393
+ . collect :: < Vec < String > > ( )
394
+ . join ( ", " ) ;
395
+
396
+ let mut query = format ! ( "SELECT * FROM c WHERE c.id IN ({})" , in_clause) ;
397
+ self . append_partition_key ( & mut query) ;
398
+ query
399
+ }
400
+
401
+ fn append_partition_key ( & self , query : & mut String ) {
402
+ append_partition_key_condition ( query, self . partition_key . as_deref ( ) ) ;
403
+ }
404
+ }
405
+
406
+ /// Appends an option partition key condition to the query.
407
+ fn append_partition_key_condition ( query : & mut String , partition_key : Option < & str > ) {
408
+ if let Some ( pk) = partition_key {
409
+ query. push_str ( " AND c.partition_key='" ) ;
410
+ query. push_str ( pk) ;
411
+ query. push ( '\'' )
412
+ }
371
413
}
372
414
373
415
#[ derive( Serialize , Deserialize , Clone , Debug ) ]
374
416
pub struct Pair {
375
417
pub id : String ,
376
418
pub value : Vec < u8 > ,
377
- pub partition_key : String ,
419
+ #[ serde( skip_serializing_if = "Option::is_none" ) ]
420
+ pub partition_key : Option < String > ,
378
421
}
379
422
380
423
impl CosmosEntity for Pair {
381
424
type Entity = String ;
382
425
383
426
fn partition_key ( & self ) -> Self :: Entity {
384
- self . partition_key . clone ( )
427
+ self . partition_key
428
+ . clone ( )
429
+ . unwrap_or_else ( || self . id . clone ( ) )
385
430
}
386
431
}
0 commit comments