Skip to content
37 changes: 20 additions & 17 deletions crypto/src/mls/credential/credential_ref/find.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use core_crypto_keystore::{entities::StoredCredential, traits::FetchFromDatabase};
use core_crypto_keystore::{
entities::{CredentialFindFilters as KeystoreFindFilters, StoredCredential},
traits::FetchFromDatabase,
};
use openmls::prelude::Credential as MlsCredential;
use tls_codec::Deserialize as _;

Expand Down Expand Up @@ -63,31 +66,31 @@ impl CredentialRef {
} = find_filters;

let partial_credentials = database
.load_all::<StoredCredential>()
.search::<StoredCredential, _>(&KeystoreFindFilters {
public_key,
earliest_validity,
session_id: client_id.map(AsRef::as_ref),
ciphersuite: ciphersuite.map(Into::into),
..Default::default()
})
.await
.map_err(KeystoreError::wrap("finding all credentials"))?
.map_err(KeystoreError::wrap("searching for credentials"))?
.into_iter()
.filter(|stored| {
client_id.is_none_or(|client_id| client_id.as_ref() == stored.session_id)
&& earliest_validity.is_none_or(|earliest_validity| earliest_validity == stored.created_at)
&& ciphersuite.is_none_or(|ciphersuite| u16::from(ciphersuite) == stored.ciphersuite)
&& public_key.is_none_or(|public_key| public_key == stored.public_key)
.map(|stored| {
MlsCredential::tls_deserialize_exact(&stored.credential)
.map_err(Error::tls_deserialize("Credential"))
.map(|mls_credential| (mls_credential, stored))
})
.map(|stored| -> Result<_> {
let mls_credential = MlsCredential::tls_deserialize_exact(&stored.credential)
.map_err(Error::tls_deserialize("Credential"))?;
Ok((mls_credential, stored))
.filter(|maybe_credential| {
maybe_credential.as_ref().ok().is_none_or(|(mls_credential, _)| {
credential_type.is_none_or(|credential_type| credential_type == mls_credential.credential_type())
})
});

let mut out = Vec::new();
for partial in partial_credentials {
let (ref mls_credential, ref stored_credential) = partial?;

if credential_type.is_some_and(|credential_type| credential_type != mls_credential.credential_type()) {
// credential type did not match
continue;
}

if let Ok(r#type) = mls_credential.credential_type().try_into()
&& let Ok(ciphersuite) = stored_credential.ciphersuite.try_into()
{
Expand Down
28 changes: 27 additions & 1 deletion keystore/src/entities/mls.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use zeroize::{Zeroize, ZeroizeOnDrop};

use crate::{
CryptoKeystoreResult,
CryptoKeystoreResult, Sha256Hash,
traits::{EntityBase, EntityGetBorrowed as _, KeyType, OwnedKeyType, PrimaryKey, SearchableEntity as _},
};

Expand Down Expand Up @@ -207,6 +207,32 @@ impl StoredBufferedCommit {
}
}

/// This type exists so that we can efficiently search for credentials by a variety of metrics at the database level.
///
/// This includes some but not all of the fields in `core_crypto::CredentialFindFilters`: those that are actually stored
/// in the database, and do not require deserializing the `credential` field.
#[derive(Debug, Default, Clone, Copy, serde::Serialize)]
pub struct CredentialFindFilters<'a> {
/// Hash of public key to search for.
pub hash: Option<Sha256Hash>,
/// Public key to search for
pub public_key: Option<&'a [u8]>,
/// Session / Client id to search for
pub session_id: Option<&'a [u8]>,
/// Ciphersuite to search for
pub ciphersuite: Option<u16>,
/// unix timestamp (seconds) of point of earliest validity to search for
pub earliest_validity: Option<u64>,
}

impl<'a> KeyType for CredentialFindFilters<'a> {
fn bytes(&self) -> std::borrow::Cow<'_, [u8]> {
postcard::to_stdvec(self)
.expect("serializing these filters cannot fail")
.into()
}
}

/// Entity representing a persisted `Credential`
#[derive(core_crypto_macros::Debug, Clone, PartialEq, Eq, Zeroize, serde::Serialize, serde::Deserialize)]
#[zeroize(drop)]
Expand Down
110 changes: 98 additions & 12 deletions keystore/src/entities/platform/generic/mls/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ use std::{
};

use async_trait::async_trait;
use rusqlite::{OptionalExtension as _, Row, Transaction, params};
use rusqlite::{OptionalExtension as _, Row, Transaction, named_params};

use crate::{
CryptoKeystoreError, CryptoKeystoreResult, Sha256Hash,
connection::{DatabaseConnection, KeystoreDatabaseConnection, TransactionWrapper},
entities::{StoredCredential, count_helper, count_helper_tx, delete_helper},
traits::{BorrowPrimaryKey, Entity, EntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType, PrimaryKey},
entities::{CredentialFindFilters, StoredCredential, count_helper, count_helper_tx, delete_helper},
traits::{
BorrowPrimaryKey, Entity, EntityBase, EntityDatabaseMutation, EntityDeleteBorrowed, KeyType, PrimaryKey,
SearchableEntity,
},
};

impl StoredCredential {
Expand Down Expand Up @@ -154,16 +157,24 @@ impl<'a> EntityDatabaseMutation<'a> for StoredCredential {
created_at,
ciphersuite,
private_key
) VALUES (?, ?, ?, ?, datetime(?, 'unixepoch'), ?, ?)",
) VALUES (
:public_key_sha256,
:public_key,
:session_id,
:credential,
datetime(:created_at, 'unixepoch'),
:ciphersuite,
:private_key
)",
)?;
stmt.execute(params![
self.primary_key(),
self.public_key,
self.session_id,
self.credential,
self.created_at,
self.ciphersuite,
self.private_key,
stmt.execute(named_params![
":public_key_sha256": self.primary_key(),
":public_key": self.public_key,
":session_id": self.session_id,
":credential": self.credential,
":created_at": self.created_at,
":ciphersuite": self.ciphersuite,
":private_key": self.private_key,
])?;

Ok(())
Expand All @@ -177,3 +188,78 @@ impl<'a> EntityDatabaseMutation<'a> for StoredCredential {
delete_helper::<Self>(tx, "public_key_sha256", id).await
}
}

#[async_trait]
impl<'a> SearchableEntity<CredentialFindFilters<'a>> for StoredCredential {
async fn find_all_matching(
conn: &mut Self::ConnectionType,
filters: &CredentialFindFilters<'a>,
) -> CryptoKeystoreResult<Vec<Self>> {
// if we know something unique to this credential, use a more-efficient search
if let Some(hash) = filters.hash.or_else(|| filters.public_key.map(Sha256Hash::hash_from)) {
return Self::get(conn, &hash).await.map(|optional| {
optional
.into_iter()
.filter(|credential| credential.matches(filters))
.collect()
});
}

let CredentialFindFilters {
ciphersuite,
earliest_validity,
session_id,
..
} = filters;

let mut query = "SELECT
session_id,
credential,
unixepoch(created_at) AS created_at,
ciphersuite,
public_key,
private_key
FROM mls_credentials
WHERE (true OR :ciphersuite OR :created_at OR :session_id) "
.to_owned();

if ciphersuite.is_some() {
query.push_str("AND ciphersuite = :ciphersuite ");
}
if earliest_validity.is_some() {
query.push_str("AND unixepoch(created_at) = :created_at ");
}
if session_id.is_some() {
query.push_str("AND session_id = :session_id ");
}
if session_id.is_some() {
query.push_str("AND session_id = ?3 ");
}

let conn = conn.conn().await;
let mut stmt = conn.prepare(&query)?;
stmt.query_map(
named_params![":ciphersuite": ciphersuite, ":created_at": earliest_validity, ":session_id": session_id],
Self::from_row,
)?
.collect::<Result<_, _>>()
.map_err(Into::into)
}

fn matches(
&self,
CredentialFindFilters {
hash,
public_key,
ciphersuite,
earliest_validity,
session_id,
}: &CredentialFindFilters<'a>,
) -> bool {
hash.is_none_or(|hash| hash == self.primary_key())
&& public_key.is_none_or(|public_key| public_key == self.public_key)
&& session_id.is_none_or(|session_id| session_id == self.session_id)
&& ciphersuite.is_none_or(|ciphersuite| ciphersuite == self.ciphersuite)
&& earliest_validity.is_none_or(|earliest_validity| earliest_validity == self.created_at)
}
}
50 changes: 48 additions & 2 deletions keystore/src/entities/platform/wasm/mls/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use web_time::SystemTime;
use crate::{
CryptoKeystoreError, CryptoKeystoreResult, Sha256Hash,
connection::{KeystoreDatabaseConnection, TransactionWrapper},
entities::StoredCredential,
entities::{CredentialFindFilters, StoredCredential},
traits::{
DecryptData, Decryptable, Decrypting, EncryptData, Encrypting, Entity, EntityBase, EntityDatabaseMutation,
KeyType as _, PrimaryKey,
KeyType as _, PrimaryKey, SearchableEntity,
},
};

Expand Down Expand Up @@ -146,3 +146,49 @@ impl Decrypting<'static> for StoredCredentialDecrypt {
impl Decryptable<'static> for StoredCredential {
type DecryptableFrom = StoredCredentialDecrypt;
}

#[async_trait(?Send)]
impl<'a> SearchableEntity<CredentialFindFilters<'a>> for StoredCredential {
async fn find_all_matching(
conn: &mut Self::ConnectionType,
filters: &CredentialFindFilters<'a>,
) -> CryptoKeystoreResult<Vec<Self>> {
// if we know something unique to this credential, use a more-efficient search
if let Some(hash) = filters.hash.or_else(|| filters.public_key.map(Sha256Hash::hash_from)) {
return Self::get(conn, &hash).await.map(|optional| {
optional
.into_iter()
.filter(|credential| credential.matches(filters))
.collect()
});
}

// We intentionally are using a kind of dumb filtering method here.
// This is OK because we are intentionally not adding big fancy heavy compound indices
// for this type. But what this means is that we end up doing a linear scan here (just
// like Sqlite is doing under the hood).
//
// It's fine though; at least this lets us skip the cost of deserializing all the credentials
// which have been filtered out.
let mut credentials = Self::load_all(conn).await?;
credentials.retain(|credential| credential.matches(filters));
Ok(credentials)
}

fn matches(
&self,
CredentialFindFilters {
hash,
public_key,
ciphersuite,
earliest_validity,
session_id,
}: &CredentialFindFilters<'a>,
) -> bool {
hash.is_none_or(|hash| hash == self.primary_key())
&& public_key.is_none_or(|public_key| public_key == self.public_key)
&& session_id.is_none_or(|session_id| session_id == self.session_id)
&& ciphersuite.is_none_or(|ciphersuite| ciphersuite == self.ciphersuite)
&& earliest_validity.is_none_or(|earliest_validity| earliest_validity == self.created_at)
}
}
1 change: 1 addition & 0 deletions keystore/src/hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub(crate) fn sha256(data: &[u8]) -> String {
derive_more::AsRef,
derive_more::From,
derive_more::Into,
serde::Serialize,
)]
#[as_ref(forward)]
pub struct Sha256Hash([u8; 32]);
Expand Down
Loading
Loading