Skip to content

Commit 49a4b2a

Browse files
committed
Move parallelization to rust
This should minimize the overhead and add parallelization to web.
1 parent b17324d commit 49a4b2a

File tree

5 files changed

+115
-47
lines changed

5 files changed

+115
-47
lines changed

lib/findMy/decrypt_reports.dart

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,34 @@ import 'package:openhaystack_mobile/ffi/ffi.dart'
1212

1313
class DecryptReports {
1414
/// Decrypts a given [FindMyReport] with the given private key.
15-
static Future<List<FindMyLocationReport>> decryptReportChunk(List<FindMyReport> reportChunk, Uint8List privateKeyBytes) async {
15+
static Future<List<FindMyLocationReport>> decryptReports(List<FindMyReport> reports, Uint8List privateKeyBytes) async {
1616
final curveDomainParam = ECCurve_secp224r1();
1717

18-
final ephemeralKeyChunk = reportChunk.map((report) {
18+
final ephemeralKeys = reports.map((report) {
1919
final payloadData = report.payload;
2020
final ephemeralKeyBytes = payloadData.sublist(5, 62);
2121
return ephemeralKeyBytes;
2222
}).toList();
2323

24-
late final List<Uint8List> sharedKeyChunk;
24+
late final List<Uint8List> sharedKeys;
2525

2626
try {
2727
debugPrint("Trying native ECDH");
28-
final ephemeralKeyBlob = Uint8List.fromList(ephemeralKeyChunk.expand((element) => element).toList());
28+
final ephemeralKeyBlob = Uint8List.fromList(ephemeralKeys.expand((element) => element).toList());
2929
final sharedKeyBlob = await api.ecdh(publicKeyBlob: ephemeralKeyBlob, privateKey: privateKeyBytes);
30-
final chunkSize = (sharedKeyBlob.length / ephemeralKeyChunk.length).ceil();
31-
sharedKeyChunk = [
32-
for (var i = 0; i < sharedKeyBlob.length; i += chunkSize)
33-
sharedKeyBlob.sublist(i, i + chunkSize < sharedKeyBlob.length ? i + chunkSize : sharedKeyBlob.length),
30+
final keySize = (sharedKeyBlob.length / ephemeralKeys.length).ceil();
31+
sharedKeys = [
32+
for (var i = 0; i < sharedKeyBlob.length; i += keySize)
33+
sharedKeyBlob.sublist(i, i + keySize < sharedKeyBlob.length ? i + keySize : sharedKeyBlob.length),
3434
];
3535
}
3636
catch (e) {
3737
debugPrint("Native ECDH failed: $e");
38-
debugPrint("Falling back to pure Dart ECDH!");
38+
debugPrint("Falling back to pure Dart ECDH on single thread!");
3939
final privateKey = ECPrivateKey(
4040
pc_utils.decodeBigIntWithSign(1, privateKeyBytes),
4141
curveDomainParam);
42-
sharedKeyChunk = ephemeralKeyChunk.map((ephemeralKey) {
42+
sharedKeys = ephemeralKeys.map((ephemeralKey) {
4343
final decodePoint = curveDomainParam.curve.decodePoint(ephemeralKey);
4444
final ephemeralPublicKey = ECPublicKey(decodePoint, curveDomainParam);
4545

@@ -48,8 +48,8 @@ class DecryptReports {
4848
}).toList();
4949
}
5050

51-
final decryptedLocationChunk = reportChunk.mapIndexed((index, report) {
52-
final derivedKey = _kdf(sharedKeyChunk[index], ephemeralKeyChunk[index]);
51+
final decryptedLocations = reports.mapIndexed((index, report) {
52+
final derivedKey = _kdf(sharedKeys[index], ephemeralKeys[index]);
5353
final payloadData = report.payload;
5454
_decodeTimeAndConfidence(payloadData, report);
5555
final encData = payloadData.sublist(62, 72);
@@ -59,7 +59,7 @@ class DecryptReports {
5959
return locationReport;
6060
}).toList();
6161

62-
return decryptedLocationChunk;
62+
return decryptedLocations;
6363
}
6464

6565
/// Decodes the unencrypted timestamp and confidence

lib/findMy/find_my_controller.dart

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import 'dart:collection';
22
import 'dart:convert';
33
import 'dart:isolate';
44
import 'dart:typed_data';
5-
import 'dart:io' as IO;
65

76
import 'package:flutter/foundation.dart';
87
import 'package:flutter_secure_storage/flutter_secure_storage.dart';
@@ -38,18 +37,8 @@ class FindMyController {
3837
FindMyKeyPair keyPair = args[0];
3938
String seemooEndpoint = args[1];
4039
final jsonReports = await ReportsFetcher.fetchLocationReports(keyPair.getHashedAdvertisementKey(), seemooEndpoint);
41-
final numChunks = kIsWeb ? 1 : IO.Platform.numberOfProcessors+1;
42-
final chunkSize = (jsonReports.length / numChunks).ceil();
43-
final chunks = [
44-
for (var i = 0; i < jsonReports.length; i += chunkSize)
45-
jsonReports.sublist(i, i + chunkSize < jsonReports.length ? i + chunkSize : jsonReports.length),
46-
];
47-
final decryptedLocations = await Future.wait(chunks.map((jsonChunk) async {
48-
final decryptedChunk = await compute(_decryptChunk, [jsonChunk, keyPair, keyPair.privateKeyBase64!]);
49-
return decryptedChunk;
50-
}));
51-
final results = decryptedLocations.expand((element) => element).toList();
52-
return results;
40+
final decryptedLocations = await _decryptReports(jsonReports, keyPair, keyPair.privateKeyBase64!);
41+
return decryptedLocations;
5342
}
5443

5544
/// Loads the private key from the local cache or secure storage and adds it
@@ -77,12 +66,8 @@ class FindMyController {
7766

7867
/// Decrypts the encrypted reports with the given list of [FindMyKeyPair] and private key.
7968
/// Returns the list of decrypted reports as a list of [FindMyLocationReport].
80-
static Future<List<FindMyLocationReport>> _decryptChunk(List<dynamic> args) async {
81-
List<dynamic> jsonChunk = args[0];
82-
FindMyKeyPair keyPair = args[1];
83-
String privateKey = args[2];
84-
85-
final reportChunk = jsonChunk.map((jsonReport) {
69+
static Future<List<FindMyLocationReport>> _decryptReports(List<dynamic> jsonRerportList, FindMyKeyPair keyPair, String privateKey) async {
70+
final reportChunk = jsonRerportList.map((jsonReport) {
8671
assert (jsonReport["id"]! == keyPair.getHashedAdvertisementKey(),
8772
"Returned FindMyReport hashed key != requested hashed key");
8873

@@ -98,7 +83,7 @@ class FindMyController {
9883
return report;
9984
}).toList();
10085

101-
final decryptedReports = await DecryptReports.decryptReportChunk(reportChunk, base64Decode(privateKey));
86+
final decryptedReports = await DecryptReports.decryptReports(reportChunk, base64Decode(privateKey));
10287

10388
return decryptedReports;
10489
}

native/Cargo.lock

Lines changed: 81 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ crate-type = ["staticlib", "cdylib", "rlib"]
1010
flutter_rust_bridge = "^1.77.0"
1111
p224 = "^0.13.2"
1212
getrandom = "^0.2.9"
13+
rayon = "1.7.0"
1314

1415
[features]
1516
default = ["p224/ecdh", "getrandom/js"]

native/src/api.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
use p224::{SecretKey, PublicKey, ecdh::diffie_hellman};
2+
use rayon::prelude::*;
3+
use std::sync::{Arc, Mutex};
24

35
const PRIVATE_LEN : usize = 28;
46
const PUBLIC_LEN : usize = 57;
57

68
pub fn ecdh(public_key_blob : Vec<u8>, private_key : Vec<u8>) -> Vec<u8> {
79
let num_keys = public_key_blob.len() / PUBLIC_LEN;
8-
let mut vec_shared_secret = vec![0u8; num_keys*PRIVATE_LEN];
10+
let vec_shared_secret = Arc::new(Mutex::new(vec![0u8; num_keys*PRIVATE_LEN]));
911

1012
let private_key = SecretKey::from_slice(&private_key).unwrap();
1113
let secret_scalar = private_key.to_nonzero_scalar();
12-
13-
let mut i = 0;
14-
let mut j = 0;
1514

16-
for _i in 0..num_keys {
17-
let public_key = PublicKey::from_sec1_bytes(&public_key_blob[i..i+PUBLIC_LEN]).unwrap();
15+
(0..num_keys).into_par_iter().for_each(|i| {
16+
let start = i * PUBLIC_LEN;
17+
let end = start + PUBLIC_LEN;
18+
let public_key = PublicKey::from_sec1_bytes(&public_key_blob[start..end]).unwrap();
1819
let public_affine = public_key.as_affine();
19-
20-
let shared_secret = diffie_hellman(secret_scalar, public_affine);
20+
21+
let shared_secret = diffie_hellman(secret_scalar, public_affine);
2122
let shared_secret_ref = shared_secret.raw_secret_bytes().as_ref();
2223

24+
let start = i * PRIVATE_LEN;
25+
let end = start + PRIVATE_LEN;
2326

24-
vec_shared_secret[j..j+PRIVATE_LEN].copy_from_slice(shared_secret_ref);
27+
let mut vec_shared_secret = vec_shared_secret.lock().unwrap();
28+
vec_shared_secret[start..end].copy_from_slice(shared_secret_ref);
29+
});
2530

26-
i += PUBLIC_LEN;
27-
j += PRIVATE_LEN;
28-
}
29-
30-
return vec_shared_secret;
31+
Arc::try_unwrap(vec_shared_secret).unwrap().into_inner().unwrap()
3132
}

0 commit comments

Comments
 (0)