Skip to content

Commit 50561ae

Browse files
committed
refactor: let mls_init() take ClientId directly
The FFI function already does this but we somehow didn't change the internal function.
1 parent 2904844 commit 50561ae

File tree

10 files changed

+75
-70
lines changed

10 files changed

+75
-70
lines changed

crypto-ffi/src/core_crypto_context/mls.rs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
use std::{sync::Arc, time::Duration};
22

33
use core_crypto::{
4-
Ciphersuite as CryptoCiphersuite, ClientIdentifier, CredentialFindFilters, MlsConversationConfiguration,
5-
RecursiveError, VerifiableGroupInfo, mls::conversation::Conversation as _,
6-
transaction_context::Error as TransactionError,
4+
Ciphersuite as CryptoCiphersuite, CredentialFindFilters, MlsConversationConfiguration, RecursiveError,
5+
VerifiableGroupInfo, mls::conversation::Conversation as _, transaction_context::Error as TransactionError,
76
};
87
use tls_codec::Deserialize as _;
98

@@ -49,10 +48,7 @@ impl CoreCryptoContext {
4948
pub async fn mls_init(&self, client_id: &Arc<ClientId>, transport: Arc<dyn MlsTransport>) -> CoreCryptoResult<()> {
5049
let transport = callback_shim(transport);
5150
self.inner
52-
.mls_init(
53-
ClientIdentifier::Basic(client_id.as_ref().as_ref().to_owned()),
54-
transport,
55-
)
51+
.mls_init(client_id.as_ref().as_ref().to_owned(), transport)
5652
.await?;
5753
Ok(())
5854
}

crypto/benches/utils/mls.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ use std::{
55

66
use async_lock::RwLock;
77
use core_crypto::{
8-
CertificateBundle, Ciphersuite, ClientId, ClientIdentifier, ConnectionType, ConversationId, CoreCrypto,
9-
Credential as CcCredential, CredentialFindFilters, CredentialRef, CredentialType, Database, DatabaseKey,
10-
HistorySecret, MlsCommitBundle, MlsConversationConfiguration, MlsCryptoProvider, MlsGroupInfoBundle, MlsTransport,
11-
MlsTransportData, MlsTransportResponse,
8+
CertificateBundle, Ciphersuite, ClientId, ConnectionType, ConversationId, CoreCrypto, Credential as CcCredential,
9+
CredentialFindFilters, CredentialRef, CredentialType, Database, DatabaseKey, HistorySecret, MlsCommitBundle,
10+
MlsConversationConfiguration, MlsCryptoProvider, MlsGroupInfoBundle, MlsTransport, MlsTransportData,
11+
MlsTransportResponse,
1212
};
1313
use criterion::BenchmarkId;
1414
use openmls::{
@@ -165,21 +165,20 @@ pub async fn new_central(
165165
} else {
166166
ConnectionType::Persistent(&path)
167167
};
168-
let client_id = ClientId::from(Alphanumeric.sample_string(&mut rand::thread_rng(), 10).into_bytes());
169-
let client_identifier = ClientIdentifier::from(client_id.clone());
168+
let session_id = ClientId::from(Alphanumeric.sample_string(&mut rand::thread_rng(), 10).into_bytes());
170169
let db = Database::open(connection_type, &DatabaseKey::generate()).await.unwrap();
171170

172171
let cc = CoreCrypto::new(db);
173172
let delivery_service = Arc::<CoreCryptoTransportSuccessProvider>::default();
174173
let tx = cc.new_transaction().await.unwrap();
175-
tx.mls_init(client_identifier, delivery_service.clone()).await.unwrap();
174+
tx.mls_init(session_id.clone(), delivery_service.clone()).await.unwrap();
176175
tx.finish().await.unwrap();
177176

178177
let ctx = cc.new_transaction().await.unwrap();
179178

180179
let credential = match certificate_bundle {
181180
Some(certificate_bundle) => CcCredential::x509(ciphersuite, certificate_bundle.to_owned()).unwrap(),
182-
None => CcCredential::basic(ciphersuite, client_id).unwrap(),
181+
None => CcCredential::basic(ciphersuite, session_id).unwrap(),
183182
};
184183
let credential_ref = ctx.add_credential(credential).await.unwrap();
185184

crypto/src/ephemeral.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ use obfuscate::{Obfuscate, Obfuscated};
2828
use openmls::prelude::KeyPackageSecretEncapsulation;
2929

3030
use crate::{
31-
Ciphersuite, ClientId, ClientIdRef, ClientIdentifier, CoreCrypto, CoreCryptoTransportNotImplementedProvider,
32-
Credential, Error, MlsError, RecursiveError, Result, Session,
31+
Ciphersuite, ClientId, ClientIdRef, CoreCrypto, CoreCryptoTransportNotImplementedProvider, Credential, Error,
32+
MlsError, RecursiveError, Result, Session,
3333
mls_provider::{DatabaseKey, MlsCryptoProvider},
3434
};
3535

@@ -66,10 +66,9 @@ impl Obfuscate for HistorySecret {
6666
/// This implementation lives here instead of there for organizational reasons.
6767
pub(crate) async fn generate_history_secret(ciphersuite: Ciphersuite) -> Result<HistorySecret> {
6868
// generate a new completely arbitrary client id
69-
let client_id = uuid::Uuid::new_v4();
70-
let client_id = format!("{HISTORY_CLIENT_ID_PREFIX}-{client_id}");
71-
let client_id = ClientId::from(client_id.into_bytes());
72-
let identifier = ClientIdentifier::Basic(client_id.clone());
69+
let session_id = uuid::Uuid::new_v4();
70+
let session_id = format!("{HISTORY_CLIENT_ID_PREFIX}-{session_id}");
71+
let session_id = ClientId::from(session_id.into_bytes());
7372

7473
let database = Database::open(ConnectionType::InMemory, &DatabaseKey::generate())
7574
.await
@@ -82,14 +81,14 @@ pub(crate) async fn generate_history_secret(ciphersuite: Ciphersuite) -> Result<
8281
.map_err(RecursiveError::transaction("creating new transaction"))?;
8382

8483
let transport = Arc::new(CoreCryptoTransportNotImplementedProvider::default());
85-
tx.mls_init(identifier, transport)
84+
tx.mls_init(session_id.clone(), transport)
8685
.await
8786
.map_err(RecursiveError::transaction("initializing ephemeral cc"))?;
8887
let session = tx
8988
.session()
9089
.await
9190
.map_err(RecursiveError::transaction("Getting mls session"))?;
92-
let credential = Credential::basic(ciphersuite, client_id.clone()).map_err(RecursiveError::mls_credential(
91+
let credential = Credential::basic(ciphersuite, session_id.clone()).map_err(RecursiveError::mls_credential(
9392
"generating basic credential for ephemeral client",
9493
))?;
9594
let credential_ref = tx
@@ -112,7 +111,10 @@ pub(crate) async fn generate_history_secret(ciphersuite: Ciphersuite) -> Result<
112111
// there
113112
let _ = tx.abort().await;
114113

115-
Ok(HistorySecret { client_id, key_package })
114+
Ok(HistorySecret {
115+
client_id: session_id,
116+
key_package,
117+
})
116118
}
117119

118120
pub(crate) fn is_history_client(client_id: impl Borrow<ClientIdRef>) -> bool {

crypto/src/mls/mod.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,20 @@ mod tests {
110110
x509_test_chain.register_with_central(&context).await;
111111

112112
// phase 2: init mls_client
113-
let client_id = ClientId::from("alice");
113+
let session_id = ClientId::from("alice");
114114
let identifier = match case.credential_type {
115-
CredentialType::Basic => ClientIdentifier::Basic(client_id.clone()),
115+
CredentialType::Basic => ClientIdentifier::Basic(session_id),
116116
CredentialType::X509 => {
117-
CertificateBundle::rand_identifier(&client_id, &[x509_test_chain.find_local_intermediate_ca()])
117+
CertificateBundle::rand_identifier(&session_id, &[x509_test_chain.find_local_intermediate_ca()])
118118
}
119119
};
120+
let session_id = identifier
121+
.get_id()
122+
.expect("get session_id from identifier")
123+
.into_owned();
120124
context
121125
.mls_init(
122-
identifier.clone(),
126+
session_id.clone(),
123127
Arc::new(CoreCryptoTransportSuccessProvider::default()),
124128
)
125129
.await

crypto/src/proteus.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -637,15 +637,19 @@ mod tests {
637637
// proteus is initialized, prekeys can be generated
638638
assert!(transaction.proteus_new_prekey(1).await.is_ok());
639639
// 👇 and so a unique 'client_id' can be fetched from wire-server
640-
let client_id = ClientId::from("alice");
640+
let session_id = ClientId::from("alice");
641+
let transport = Arc::new(CoreCryptoTransportSuccessProvider::default());
641642
let identifier = match case.credential_type {
642-
CredentialType::Basic => ClientIdentifier::Basic(client_id),
643+
CredentialType::Basic => ClientIdentifier::Basic(session_id),
643644
CredentialType::X509 => {
644-
CertificateBundle::rand_identifier(&client_id, &[x509_test_chain.find_local_intermediate_ca()])
645+
CertificateBundle::rand_identifier(&session_id, &[x509_test_chain.find_local_intermediate_ca()])
645646
}
646647
};
647-
let transport = Arc::new(CoreCryptoTransportSuccessProvider::default());
648-
transaction.mls_init(identifier.clone(), transport).await.unwrap();
648+
let session_id = identifier
649+
.get_id()
650+
.expect("Getting session id from identifier")
651+
.into_owned();
652+
transaction.mls_init(session_id, transport).await.unwrap();
649653
let credential = Credential::from_identifier(&identifier, case.ciphersuite()).unwrap();
650654
let credential_ref = transaction.add_credential(credential).await.unwrap();
651655

crypto/src/test_utils/mod.rs

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,12 @@ impl SessionContext {
129129

130130
let transaction = core_crypto.new_transaction().await.unwrap();
131131

132+
let session_id = identifier
133+
.get_id()
134+
.map_err(RecursiveError::mls_client("getting client id"))?
135+
.into_owned();
132136
transaction
133-
.mls_init(identifier.clone(), context.transport.clone())
137+
.mls_init(session_id, context.transport.clone())
134138
.await
135139
.map_err(RecursiveError::transaction("mls init"))?;
136140

@@ -275,9 +279,9 @@ impl SessionContext {
275279
self.set_session(new_session).await;
276280
}
277281

278-
pub async fn reinit_session(&self, identifier: ClientIdentifier) {
282+
pub async fn reinit_session(&self, session_id: ClientId) {
279283
self.transaction
280-
.mls_init(identifier, self.mls_transport().await)
284+
.mls_init(session_id, self.mls_transport().await)
281285
.await
282286
.unwrap();
283287

@@ -323,28 +327,31 @@ impl SessionContext {
323327
) -> Result<()> {
324328
let user_uuid = uuid::Uuid::new_v4();
325329
let rnd_id = rand::random::<usize>();
326-
let client_id = format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated());
327-
let client_id = ClientId(client_id.into_bytes());
330+
let session_id = ClientId(format!("{}:{rnd_id:x}@members.wire.com", user_uuid.hyphenated()).into_bytes());
328331

329-
let credential;
330-
let identifier;
331-
match case.credential_type {
332+
let (session_id, credential) = match case.credential_type {
332333
CredentialType::Basic => {
333-
identifier = ClientIdentifier::Basic(client_id.clone());
334-
credential = Credential::basic(case.ciphersuite(), client_id).unwrap();
334+
let credential =
335+
Credential::basic(case.ciphersuite(), session_id.clone()).expect("creating basic credential ");
336+
337+
(session_id, credential)
335338
}
336339
CredentialType::X509 => {
337-
let signer = signer.expect("Missing intermediate CA").to_owned();
338-
let cert = CertificateBundle::rand(&client_id, &signer);
339-
identifier = ClientIdentifier::X509([(case.signature_scheme(), cert.clone())].into());
340-
credential = Credential::x509(case.ciphersuite(), cert).unwrap();
341-
}
342-
};
340+
let signer = signer.expect("Missing intermediate CA");
341+
let cert = CertificateBundle::rand(&session_id, signer);
342+
let session_id = cert.get_client_id().expect("Getting client id from certificate bundle");
343343

344-
self.reinit_session(identifier).await;
344+
let credential = Credential::x509(case.ciphersuite(), cert).expect("creating x509 credential");
345345

346-
self.transaction.add_credential(credential).await.unwrap();
346+
(session_id, credential)
347+
}
348+
};
347349

350+
self.reinit_session(session_id).await;
351+
self.transaction
352+
.add_credential(credential)
353+
.await
354+
.expect("adding credential");
348355
Ok(())
349356
}
350357
}

crypto/src/transaction_context/e2e_identity/mod.rs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,14 @@ mod init_certificates;
77
mod rotate;
88
mod stash;
99

10-
use std::{
11-
collections::{HashMap, HashSet},
12-
sync::Arc,
13-
};
10+
use std::{collections::HashSet, sync::Arc};
1411

1512
pub use error::{Error, Result};
1613
use wire_e2e_identity::prelude::x509::extract_crl_uris;
1714

1815
use super::TransactionContext;
1916
use crate::{
20-
CertificateBundle, Ciphersuite, ClientId, ClientIdentifier, Credential, CredentialRef, E2eiEnrollment,
21-
MlsTransport, RecursiveError,
17+
CertificateBundle, Ciphersuite, ClientId, Credential, CredentialRef, E2eiEnrollment, MlsTransport, RecursiveError,
2218
e2e_identity::NewCrlDistributionPoints,
2319
mls::credential::{crl::get_new_crl_distribution_points, x509::CertificatePrivateKey},
2420
};
@@ -106,9 +102,10 @@ impl TransactionContext {
106102
let credential_ref = credential.save(database).await.map_err(RecursiveError::mls_credential(
107103
"saving credential in e2ei_mls_init_only",
108104
))?;
109-
110-
let identifier = ClientIdentifier::X509(HashMap::from([(ciphersuite.signature_algorithm(), cert_bundle)]));
111-
self.mls_init(identifier, transport)
105+
let session_id = cert_bundle.get_client_id().map_err(RecursiveError::mls_credential(
106+
"Getting session id from certificate bundle",
107+
))?;
108+
self.mls_init(session_id, transport)
112109
.await
113110
.map_err(RecursiveError::transaction("initializing mls"))?;
114111
Ok((credential_ref, crl_new_distribution_points))

crypto/src/transaction_context/e2e_identity/rotate.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ mod tests {
430430
keystore.commit_transaction().await.unwrap();
431431
keystore.new_transaction().await.unwrap();
432432

433-
alice.reinit_session(alice.identifier.clone()).await;
433+
alice.reinit_session(alice.get_client_id().await).await;
434434

435435
let new_session = alice.session().await;
436436
// Verify that Alice has the same credentials

crypto/src/transaction_context/mod.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use wire_e2e_identity::pki_env::PkiEnvironment;
1414
#[cfg(feature = "proteus")]
1515
use crate::proteus::ProteusCentral;
1616
use crate::{
17-
ClientId, ClientIdentifier, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation,
18-
MlsError, MlsTransport, RecursiveError, Session,
17+
ClientId, CoreCrypto, CredentialFindFilters, CredentialRef, KeystoreError, MlsConversation, MlsError, MlsTransport,
18+
RecursiveError, Session,
1919
group_store::GroupStore,
2020
mls::{self, HasSessionAndCrypto},
2121
mls_provider::{Database, MlsCryptoProvider},
@@ -261,12 +261,8 @@ impl TransactionContext {
261261
}
262262

263263
/// Initializes the MLS client of [super::CoreCrypto].
264-
pub async fn mls_init(&self, identifier: ClientIdentifier, transport: Arc<dyn MlsTransport>) -> Result<()> {
264+
pub async fn mls_init(&self, session_id: ClientId, transport: Arc<dyn MlsTransport>) -> Result<()> {
265265
let database = self.database().await?;
266-
let client_id = identifier
267-
.get_id()
268-
.map_err(RecursiveError::mls_client("getting client id"))?
269-
.into_owned();
270266

271267
let pki_env_provider = self
272268
.pki_environment_option()
@@ -276,7 +272,7 @@ impl TransactionContext {
276272

277273
let crypto_provider = MlsCryptoProvider::new_with_pki_env(database, pki_env_provider);
278274
let database = self.database().await?;
279-
let session = Session::new(client_id.clone(), crypto_provider, database, transport);
275+
let session = Session::new(session_id.clone(), crypto_provider, database, transport);
280276
self.set_mls_session(session).await?;
281277

282278
Ok(())

interop/src/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ async fn run_mls_test(chrome_driver_addr: &std::net::SocketAddr, web_server: &st
160160
let transaction = cc.new_transaction().await?;
161161
let success_provider = Arc::new(MlsTransportSuccessProvider::default());
162162
transaction
163-
.mls_init(master_client_id.clone().into(), success_provider.clone())
163+
.mls_init(master_client_id.clone(), success_provider.clone())
164164
.await?;
165165
let credential_ref = transaction.add_credential(credential).await?;
166166
transaction

0 commit comments

Comments
 (0)