Skip to content

Commit dab4f4b

Browse files
apollo_propeller: add reed-solomon erasure coding (#11047)
1 parent 7dca152 commit dab4f4b

File tree

6 files changed

+182
-0
lines changed

6 files changed

+182
-0
lines changed

Cargo.lock

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,7 @@ quote = "1.0.26"
315315
rand = "0.8.5"
316316
rand_chacha = "0.3.1"
317317
rand_distr = "0.4.3"
318+
reed-solomon-simd = "3.1.0"
318319
regex = "1.10.4"
319320
replace_with = "0.1.7"
320321
reqwest = "0.12"

crates/apollo_propeller/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ description = "Implementation of the Propeller algorithm for the Starknet sequen
1010
testing = []
1111

1212
[dependencies]
13+
reed-solomon-simd.workspace = true
1314
sha2.workspace = true
1415

1516
[dev-dependencies]

crates/apollo_propeller/src/lib.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
pub mod merkle;
22
#[cfg(test)]
33
mod merkle_test;
4+
// TODO(AndrewL): Consider renaming this to `erasure_coding` or `error_correction_code`.
5+
pub mod reed_solomon;
6+
#[cfg(test)]
7+
mod reed_solomon_test;
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use reed_solomon_simd::{ReedSolomonDecoder, ReedSolomonEncoder};
2+
3+
// TODO(AndrewL): Consider combining this with `generate_coding_shards`.
4+
// TODO(AndrewL): Consider adding custom error type and using it here.
5+
pub fn split_data_into_shards(message: Vec<u8>, num_data_shards: usize) -> Option<Vec<Vec<u8>>> {
6+
if !message.len().is_multiple_of(num_data_shards) {
7+
return None;
8+
}
9+
let shard_size = message.len() / num_data_shards;
10+
Some(message.chunks_exact(shard_size).map(|chunk| chunk.to_vec()).collect())
11+
}
12+
13+
/// Generate coding shards using Reed-Solomon encoding.
14+
// TODO(AndrewL): Consider adding custom error type and using it here.
15+
pub fn generate_coding_shards(
16+
data_shards: &[Vec<u8>],
17+
num_coding_shards: usize,
18+
) -> Result<Vec<Vec<u8>>, String> {
19+
if num_coding_shards == 0 {
20+
// ReedSolomonEncoder does not support 0 coding shards
21+
return Ok(Vec::new());
22+
}
23+
24+
let num_data_shards = data_shards.len();
25+
// TODO(AndrewL): Consider accepting a shard size as an argument.
26+
let shard_size = data_shards.first().ok_or("No data shards".to_string())?.len();
27+
28+
let mut encoder = ReedSolomonEncoder::new(num_data_shards, num_coding_shards, shard_size)
29+
.map_err(|e| format!("Failed to create Reed-Solomon encoder: {}", e))?;
30+
31+
for shard in data_shards.iter().take(num_data_shards) {
32+
encoder
33+
.add_original_shard(shard)
34+
.map_err(|e| format!("Failed to add data shard: {}", e))?;
35+
}
36+
37+
let result = encoder.encode().map_err(|e| format!("Failed to encode: {}", e))?;
38+
39+
let coding_shards = result.recovery_iter().map(|shard| shard.to_vec()).collect();
40+
41+
Ok(coding_shards)
42+
}
43+
44+
/// Reconstruct the original message from available shards using Reed-Solomon error correction.
45+
// TODO(AndrewL): Consider adding custom error type and using it here.
46+
// TODO(AndrewL): Rename this to `reconstruct_data_shards`.
47+
pub fn reconstruct_message_from_shards(
48+
// TODO(AndrewL): Change this to a HashMap<usize, Vec<u8>>.
49+
shards: &[(usize, Vec<u8>)],
50+
num_data_shards: usize,
51+
num_coding_shards: usize,
52+
) -> Result<Vec<Vec<u8>>, String> {
53+
if num_coding_shards == 0 {
54+
return Ok(shards.iter().map(|(_, s)| s.to_vec()).collect());
55+
}
56+
// TODO(AndrewL): Consider accepting a shard size as an argument.
57+
let shard_size = shards.first().ok_or("No shards".to_string())?.1.len();
58+
59+
let mut decoder = ReedSolomonDecoder::new(num_data_shards, num_coding_shards, shard_size)
60+
.map_err(|e| format!("Failed to create Reed-Solomon decoder: {}", e))?;
61+
62+
for (index, shard_data) in shards {
63+
if *index < num_data_shards {
64+
decoder
65+
.add_original_shard(*index, shard_data)
66+
.map_err(|e| format!("Failed to add original shard: {}", e))?;
67+
} else {
68+
decoder
69+
.add_recovery_shard(index - num_data_shards, shard_data)
70+
.map_err(|e| format!("Failed to add coding shard: {}", e))?;
71+
}
72+
}
73+
74+
let result = decoder.decode().map_err(|e| format!("Failed to decode: {}", e))?;
75+
76+
let mut shard_map = std::collections::HashMap::new();
77+
for (index, shard) in shards {
78+
shard_map.insert(index, shard);
79+
}
80+
81+
let mut data_shards = Vec::with_capacity(num_data_shards);
82+
for index in 0..num_data_shards {
83+
if let Some(shard_shard) = shard_map.get(&index) {
84+
data_shards.push(shard_shard.to_vec());
85+
} else if let Some(restored_data) = result.restored_original(index) {
86+
data_shards.push(restored_data.to_vec());
87+
} else {
88+
return Err(format!(
89+
"Missing data shard at index {} and no restored data available",
90+
index
91+
));
92+
}
93+
}
94+
95+
Ok(data_shards)
96+
}
97+
98+
pub fn combine_data_shards(data_shards: Vec<Vec<u8>>) -> Vec<u8> {
99+
data_shards.iter().flatten().copied().collect()
100+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
use rstest::rstest;
2+
3+
use crate::reed_solomon::{generate_coding_shards, reconstruct_message_from_shards};
4+
5+
#[test]
6+
fn test_empty_generate_coding_shards() {
7+
let data_shards = vec![vec![0, 1], vec![2, 3], vec![4, 5]];
8+
let num_coding_shards = 0;
9+
let coding_shards = generate_coding_shards(&data_shards, num_coding_shards).unwrap();
10+
assert!(coding_shards.is_empty());
11+
}
12+
13+
#[rstest]
14+
#[case(3, 2, 1, 8)]
15+
#[case(4, 2, 2, 16)]
16+
#[case(5, 3, 3, 8)]
17+
#[case(4, 4, 4, 32)]
18+
fn test_reed_solomon_with_lost_shards(
19+
#[case] num_data_shards: u8,
20+
#[case] num_coding_shards: usize,
21+
#[case] num_lost_shards: usize,
22+
#[case] shard_size: usize,
23+
) {
24+
let data_shards: Vec<Vec<u8>> = (0..num_data_shards).map(|i| vec![i; shard_size]).collect();
25+
let num_data_shards: usize = num_data_shards.into();
26+
let original_data = data_shards.clone();
27+
28+
let coding_shards = generate_coding_shards(&data_shards, num_coding_shards).unwrap();
29+
assert_eq!(coding_shards.len(), num_coding_shards);
30+
31+
let all_shards: Vec<(usize, Vec<u8>)> = data_shards
32+
.iter()
33+
.enumerate()
34+
.map(|(i, s)| (i, s.clone()))
35+
.chain(coding_shards.iter().enumerate().map(|(i, s)| (num_data_shards + i, s.clone())))
36+
.collect();
37+
38+
let available_shards: Vec<(usize, Vec<u8>)> = all_shards
39+
.into_iter()
40+
.enumerate()
41+
.filter(|(i, _)| *i % 2 != 0 || *i >= num_lost_shards * 2)
42+
.map(|(_, shard)| shard)
43+
.collect();
44+
45+
assert!(
46+
available_shards.len() >= num_data_shards,
47+
"Not enough shards to reconstruct (have {}, need {})",
48+
available_shards.len(),
49+
num_data_shards
50+
);
51+
52+
let reconstructed_data =
53+
reconstruct_message_from_shards(&available_shards, num_data_shards, num_coding_shards)
54+
.unwrap();
55+
56+
assert_eq!(reconstructed_data, original_data, "Reconstructed data doesn't match original");
57+
}

0 commit comments

Comments
 (0)