Skip to content

Commit b9ae38c

Browse files
committed
First working version
Signed-off-by: Ryan Levick <[email protected]>
1 parent 26ff330 commit b9ae38c

File tree

3 files changed

+56
-40
lines changed

3 files changed

+56
-40
lines changed

crates/key-value-azure/src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@ use store::{
77
};
88

99
/// A key-value store that uses Azure Cosmos as the backend.
10-
#[derive(Default)]
1110
pub struct AzureKeyValueStore {
12-
_priv: (),
11+
app_id: String,
1312
}
1413

1514
impl AzureKeyValueStore {
1615
/// Creates a new `AzureKeyValueStore`.
17-
pub fn new() -> Self {
18-
Self::default()
16+
pub fn new(app_id: String) -> Self {
17+
Self { app_id }
1918
}
2019
}
2120

@@ -55,6 +54,7 @@ impl MakeKeyValueStore for AzureKeyValueStore {
5554
runtime_config.database,
5655
runtime_config.container,
5756
auth_options,
57+
self.app_id.clone(),
5858
)
5959
}
6060
}

crates/key-value-azure/src/store.rs

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ use std::sync::{Arc, Mutex};
1313

1414
pub struct KeyValueAzureCosmos {
1515
client: CollectionClient,
16+
app_id: String,
1617
}
1718

1819
/// Azure Cosmos Key / Value runtime config literal options for authentication
@@ -71,6 +72,7 @@ impl KeyValueAzureCosmos {
7172
database: String,
7273
container: String,
7374
auth_options: KeyValueAzureCosmosAuthOptions,
75+
app_id: String,
7476
) -> Result<Self> {
7577
let token = match auth_options {
7678
KeyValueAzureCosmosAuthOptions::RuntimeConfigValues(config) => {
@@ -86,15 +88,16 @@ impl KeyValueAzureCosmos {
8688
let database_client = cosmos_client.database_client(database);
8789
let client = database_client.collection_client(container);
8890

89-
Ok(Self { client })
91+
Ok(Self { client, app_id })
9092
}
9193
}
9294

9395
#[async_trait]
9496
impl StoreManager for KeyValueAzureCosmos {
95-
async fn get(&self, _name: &str) -> Result<Arc<dyn Store>, Error> {
97+
async fn get(&self, name: &str) -> Result<Arc<dyn Store>, Error> {
9698
Ok(Arc::new(AzureCosmosStore {
9799
client: self.client.clone(),
100+
partition_key: format!("{}/{}", self.app_id, name),
98101
}))
99102
}
100103

@@ -114,13 +117,7 @@ impl StoreManager for KeyValueAzureCosmos {
114117
#[derive(Clone)]
115118
struct AzureCosmosStore {
116119
client: CollectionClient,
117-
}
118-
119-
struct CompareAndSwap {
120-
key: String,
121-
client: CollectionClient,
122-
bucket_rep: u32,
123-
etag: Mutex<Option<String>>,
120+
partition_key: String,
124121
}
125122

126123
#[async_trait]
@@ -134,6 +131,7 @@ impl Store for AzureCosmosStore {
134131
let pair = Pair {
135132
id: key.to_string(),
136133
value: value.to_vec(),
134+
partition_key: self.partition_key.clone(),
137135
};
138136
self.client
139137
.create_document(pair)
@@ -145,7 +143,10 @@ impl Store for AzureCosmosStore {
145143

146144
async fn delete(&self, key: &str) -> Result<(), Error> {
147145
if self.exists(key).await? {
148-
let document_client = self.client.document_client(key, &key).map_err(log_error)?;
146+
let document_client = self
147+
.client
148+
.document_client(key, &self.partition_key)
149+
.map_err(log_error)?;
149150
document_client.delete_document().await.map_err(log_error)?;
150151
}
151152
Ok(())
@@ -165,7 +166,10 @@ impl Store for AzureCosmosStore {
165166
.map(|k| format!("'{}'", k))
166167
.collect::<Vec<String>>()
167168
.join(", ");
168-
let stmt = Query::new(format!("SELECT * FROM c WHERE c.id IN ({})", in_clause));
169+
let stmt = Query::new(format!(
170+
"SELECT * FROM c WHERE c.id IN ({}) AND partition_key='{}'",
171+
in_clause, self.partition_key
172+
));
169173
let query = self
170174
.client
171175
.query_documents(stmt)
@@ -175,9 +179,11 @@ impl Store for AzureCosmosStore {
175179
let mut stream = query.into_stream::<Pair>();
176180
while let Some(resp) = stream.next().await {
177181
let resp = resp.map_err(log_error)?;
178-
for (pair, _) in resp.results {
179-
res.push((pair.id, Some(pair.value)));
180-
}
182+
res.extend(
183+
resp.results
184+
.into_iter()
185+
.map(|(pair, _)| (pair.id, Some(pair.value))),
186+
);
181187
}
182188
Ok(res)
183189
}
@@ -200,7 +206,7 @@ impl Store for AzureCosmosStore {
200206
let operations = vec![Operation::incr("/value", delta).map_err(log_error)?];
201207
let _ = self
202208
.client
203-
.document_client(key.clone(), &key.as_str())
209+
.document_client(key.clone(), &self.partition_key)
204210
.map_err(log_error)?
205211
.patch_document(operations)
206212
.await
@@ -227,10 +233,19 @@ impl Store for AzureCosmosStore {
227233
client: self.client.clone(),
228234
etag: Mutex::new(None),
229235
bucket_rep,
236+
partition_key: self.partition_key.clone(),
230237
}))
231238
}
232239
}
233240

241+
struct CompareAndSwap {
242+
key: String,
243+
client: CollectionClient,
244+
bucket_rep: u32,
245+
etag: Mutex<Option<String>>,
246+
partition_key: String,
247+
}
248+
234249
#[async_trait]
235250
impl Cas for CompareAndSwap {
236251
/// `current` will fetch the current value for the key and store the etag for the record. The
@@ -239,8 +254,8 @@ impl Cas for CompareAndSwap {
239254
let mut stream = self
240255
.client
241256
.query_documents(Query::new(format!(
242-
"SELECT * FROM c WHERE c.id='{}'",
243-
self.key
257+
"SELECT * FROM c WHERE c.id='{}' and c.partition_key='{}'",
258+
self.key, self.partition_key
244259
)))
245260
.query_cross_partition(true)
246261
.max_item_count(1)
@@ -272,10 +287,11 @@ impl Cas for CompareAndSwap {
272287
/// `swap` updates the value for the key using the etag saved in the `current` function for
273288
/// optimistic concurrency.
274289
async fn swap(&self, value: Vec<u8>) -> Result<(), SwapError> {
275-
let pk = PartitionKey::from(&self.key);
290+
let pk = PartitionKey::from(&self.partition_key);
276291
let pair = Pair {
277292
id: self.key.clone(),
278293
value,
294+
partition_key: self.partition_key.clone(),
279295
};
280296

281297
let doc_client = self
@@ -318,23 +334,23 @@ impl AzureCosmosStore {
318334
async fn get_pair(&self, key: &str) -> Result<Option<Pair>, Error> {
319335
let query = self
320336
.client
321-
.query_documents(Query::new(format!("SELECT * FROM c WHERE c.id='{}'", key)))
337+
.query_documents(Query::new(format!(
338+
"SELECT * FROM c WHERE c.id='{}' AND c.partition_key='{}'",
339+
key, self.partition_key
340+
)))
322341
.query_cross_partition(true)
323342
.max_item_count(1);
324343

325344
// There can be no duplicated keys, so we create the stream and only take the first result.
326345
let mut stream = query.into_stream::<Pair>();
327-
let res = stream.next().await;
328-
match res {
329-
Some(r) => {
330-
let r = r.map_err(log_error)?;
331-
match r.results.first().cloned() {
332-
Some((p, _)) => Ok(Some(p)),
333-
None => Ok(None),
334-
}
335-
}
336-
None => Ok(None),
337-
}
346+
let Some(res) = stream.next().await else {
347+
return Ok(None);
348+
};
349+
Ok(res
350+
.map_err(log_error)?
351+
.results
352+
.first()
353+
.map(|(p, _)| p.clone()))
338354
}
339355

340356
async fn get_keys(&self) -> Result<Vec<String>, Error> {
@@ -347,9 +363,7 @@ impl AzureCosmosStore {
347363
let mut stream = query.into_stream::<Pair>();
348364
while let Some(resp) = stream.next().await {
349365
let resp = resp.map_err(log_error)?;
350-
for (pair, _) in resp.results {
351-
res.push(pair.id);
352-
}
366+
res.extend(resp.results.into_iter().map(|(pair, _)| pair.id));
353367
}
354368

355369
Ok(res)
@@ -358,15 +372,15 @@ impl AzureCosmosStore {
358372

359373
#[derive(Serialize, Deserialize, Clone, Debug)]
360374
pub struct Pair {
361-
// In Azure CosmosDB, the default partition key is "/id", and this implementation assumes that partition ID is not changed.
362375
pub id: String,
363376
pub value: Vec<u8>,
377+
pub partition_key: String,
364378
}
365379

366380
impl CosmosEntity for Pair {
367381
type Entity = String;
368382

369383
fn partition_key(&self) -> Self::Entity {
370-
self.id.clone()
384+
self.partition_key.clone()
371385
}
372386
}

crates/runtime-config/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,9 @@ pub fn key_value_config_resolver(
403403
.register_store_type(spin_key_value_redis::RedisKeyValueStore::new())
404404
.unwrap();
405405
key_value
406-
.register_store_type(spin_key_value_azure::AzureKeyValueStore::new())
406+
.register_store_type(spin_key_value_azure::AzureKeyValueStore::new(
407+
"MY_APP".to_owned(),
408+
))
407409
.unwrap();
408410
key_value
409411
.register_store_type(spin_key_value_aws::AwsDynamoKeyValueStore::new())

0 commit comments

Comments
 (0)