Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions rln/src/pm_tree_adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use zerokit_utils::{
},
pm_tree::{
pmtree,
pmtree::{tree::Key, Database, Hasher, PmtreeErrorKind},
pmtree::{tree::Key, Database, Hasher, PmtreeErrorKind, TreeErrorKind},
Config, Mode, SledDB,
},
};
Expand Down Expand Up @@ -239,9 +239,15 @@ impl ZerokitMerkleTree for PmTree {
Err(_) => pmtree::MerkleTree::new(depth, config.0)?,
};

let capacity = 1usize.checked_shl(depth as u32).ok_or({
ZerokitMerkleTreeError::PmtreeErrorKind(PmtreeErrorKind::TreeError(
TreeErrorKind::IndexOutOfBounds,
))
})?;

Ok(PmTree {
tree,
cached_leaves_indices: vec![0; 1 << depth],
cached_leaves_indices: vec![0; capacity],
metadata: Vec::new(),
})
}
Expand Down
34 changes: 18 additions & 16 deletions rln/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,9 @@ pub fn bytes_le_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize), UtilsError>
}
let len = usize::try_from(u64::from_le_bytes(input[0..VEC_LEN_BYTE_SIZE].try_into()?))?;
read += VEC_LEN_BYTE_SIZE;
if input.len() < VEC_LEN_BYTE_SIZE + len {
if len > input.len() - VEC_LEN_BYTE_SIZE {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + len,
expected: VEC_LEN_BYTE_SIZE.saturating_add(len),
actual: input.len(),
});
}
Expand All @@ -266,9 +266,9 @@ pub fn bytes_be_to_vec_u8(input: &[u8]) -> Result<(Vec<u8>, usize), UtilsError>
}
let len = usize::try_from(u64::from_be_bytes(input[0..VEC_LEN_BYTE_SIZE].try_into()?))?;
read += VEC_LEN_BYTE_SIZE;
if input.len() < VEC_LEN_BYTE_SIZE + len {
if len > input.len() - VEC_LEN_BYTE_SIZE {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + len,
expected: VEC_LEN_BYTE_SIZE.saturating_add(len),
actual: input.len(),
});
}
Expand All @@ -289,9 +289,9 @@ pub fn bytes_le_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize), UtilsError>
let len = usize::try_from(u64::from_le_bytes(input[0..VEC_LEN_BYTE_SIZE].try_into()?))?;
read += VEC_LEN_BYTE_SIZE;
let el_size = FR_BYTE_SIZE;
if input.len() < VEC_LEN_BYTE_SIZE + len * el_size {
if len > (input.len() - VEC_LEN_BYTE_SIZE) / el_size {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + len * el_size,
expected: VEC_LEN_BYTE_SIZE.saturating_add(len.saturating_mul(el_size)),
actual: input.len(),
});
}
Expand All @@ -318,9 +318,9 @@ pub fn bytes_be_to_vec_fr(input: &[u8]) -> Result<(Vec<Fr>, usize), UtilsError>
let len = usize::try_from(u64::from_be_bytes(input[0..VEC_LEN_BYTE_SIZE].try_into()?))?;
read += VEC_LEN_BYTE_SIZE;
let el_size = FR_BYTE_SIZE;
if input.len() < VEC_LEN_BYTE_SIZE + len * el_size {
if len > (input.len() - VEC_LEN_BYTE_SIZE) / el_size {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + len * el_size,
expected: VEC_LEN_BYTE_SIZE.saturating_add(len.saturating_mul(el_size)),
actual: input.len(),
});
}
Expand All @@ -347,9 +347,10 @@ pub fn bytes_le_to_vec_usize(input: &[u8]) -> Result<Vec<usize>, UtilsError> {
if nof_elem == 0 {
Ok(vec![])
} else {
if input.len() < VEC_LEN_BYTE_SIZE + nof_elem * VEC_LEN_BYTE_SIZE {
if nof_elem > (input.len() - VEC_LEN_BYTE_SIZE) / VEC_LEN_BYTE_SIZE {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + nof_elem * VEC_LEN_BYTE_SIZE,
expected: VEC_LEN_BYTE_SIZE
.saturating_add(nof_elem.saturating_mul(VEC_LEN_BYTE_SIZE)),
actual: input.len(),
});
}
Expand Down Expand Up @@ -377,9 +378,10 @@ pub fn bytes_be_to_vec_usize(input: &[u8]) -> Result<Vec<usize>, UtilsError> {
if nof_elem == 0 {
Ok(vec![])
} else {
if input.len() < VEC_LEN_BYTE_SIZE + nof_elem * VEC_LEN_BYTE_SIZE {
if nof_elem > (input.len() - VEC_LEN_BYTE_SIZE) / VEC_LEN_BYTE_SIZE {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + nof_elem * VEC_LEN_BYTE_SIZE,
expected: VEC_LEN_BYTE_SIZE
.saturating_add(nof_elem.saturating_mul(VEC_LEN_BYTE_SIZE)),
actual: input.len(),
});
}
Expand All @@ -406,9 +408,9 @@ pub fn bytes_le_to_vec_bool(input: &[u8]) -> Result<(Vec<bool>, usize), UtilsErr
}
let len = usize::try_from(u64::from_le_bytes(input[0..VEC_LEN_BYTE_SIZE].try_into()?))?;
read += VEC_LEN_BYTE_SIZE;
if input.len() < VEC_LEN_BYTE_SIZE + len {
if len > input.len() - VEC_LEN_BYTE_SIZE {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + len,
expected: VEC_LEN_BYTE_SIZE.saturating_add(len),
actual: input.len(),
});
}
Expand All @@ -431,9 +433,9 @@ pub fn bytes_be_to_vec_bool(input: &[u8]) -> Result<(Vec<bool>, usize), UtilsErr
}
let len = usize::try_from(u64::from_be_bytes(input[0..VEC_LEN_BYTE_SIZE].try_into()?))?;
read += VEC_LEN_BYTE_SIZE;
if input.len() < VEC_LEN_BYTE_SIZE + len {
if len > input.len() - VEC_LEN_BYTE_SIZE {
return Err(UtilsError::InsufficientData {
expected: VEC_LEN_BYTE_SIZE + len,
expected: VEC_LEN_BYTE_SIZE.saturating_add(len),
actual: input.len(),
});
}
Expand Down
21 changes: 21 additions & 0 deletions rln/tests/pm_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,27 @@ mod test {
));
}

#[test]
fn test_pmtree_depth_shift_overflow() {
let depth = usize::BITS as usize;
let result = PmTree::new(depth, Fr::zero(), temp_config());
assert!(matches!(
result,
Err(ZerokitMerkleTreeError::PmtreeErrorKind(_))
));
}

#[test]
fn test_pmtree_override_range_min_index_underflow() {
let mut tree = PmTree::new(TEST_DEPTH, Fr::zero(), temp_config()).unwrap();
let result =
tree.override_range(0, vec![Fr::from(1)].into_iter(), vec![5usize].into_iter());
assert!(matches!(
result,
Err(ZerokitMerkleTreeError::InvalidIndices)
));
}

#[test]
fn test_pmtree_basic_operations() {
let mut tree = PmTree::default(TEST_DEPTH).unwrap();
Expand Down
24 changes: 24 additions & 0 deletions rln/tests/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,30 @@ mod test {
assert_eq!(computed_root2, root2);
}

#[test]
fn test_bytes_le_to_rln_proof_short() {
let bytes = vec![0u8; COMPRESS_PROOF_SIZE - 1];
assert!(bytes_le_to_rln_proof(&bytes).is_err());
}

#[test]
fn test_bytes_be_to_rln_proof_short() {
let bytes = vec![0u8; COMPRESS_PROOF_SIZE - 1];
assert!(bytes_be_to_rln_proof(&bytes).is_err());
}

#[test]
fn test_bytes_le_to_rln_proof_empty() {
let bytes = vec![];
assert!(bytes_le_to_rln_proof(&bytes).is_err());
}

#[test]
fn test_bytes_be_to_rln_proof_empty() {
let bytes = vec![];
assert!(bytes_be_to_rln_proof(&bytes).is_err());
}

#[test]
fn test_rln_witness_to_bigint_json_fields() {
// Test with default witness
Expand Down
19 changes: 19 additions & 0 deletions rln/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,25 @@ mod test {
assert!(bytes_be_to_vec_u8(&valid_u8_data_be).is_ok());
}

#[test]
fn test_length_prefix_overflow() {
let mut overflow_u8 = vec![0u8; 8];
overflow_u8[..8].copy_from_slice(&normalize_usize_le(usize::MAX));
assert!(bytes_le_to_vec_u8(&overflow_u8).is_err());

let mut overflow_u8_be = vec![0u8; 8];
overflow_u8_be[..8].copy_from_slice(&normalize_usize_be(usize::MAX));
assert!(bytes_be_to_vec_u8(&overflow_u8_be).is_err());

let mut overflow_fr = vec![0u8; 8];
overflow_fr[..8].copy_from_slice(&normalize_usize_le(usize::MAX));
assert!(bytes_le_to_vec_fr(&overflow_fr).is_err());

let mut overflow_fr_be = vec![0u8; 8];
overflow_fr_be[..8].copy_from_slice(&normalize_usize_be(usize::MAX));
assert!(bytes_be_to_vec_fr(&overflow_fr_be).is_err());
}

#[test]
fn test_empty_vectors() {
// Test empty vector serialization/deserialization
Expand Down
2 changes: 2 additions & 0 deletions utils/src/merkle_tree/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ pub enum ZerokitMerkleTreeError {
InvalidSubTreeIndex,
#[error("Start level is != from end level")]
InvalidStartAndEndLevel,
#[error("Tree depth exceeds maximum allowed (must be < {})", usize::BITS)]
InvalidDepth,
#[error("set_range got too many leaves")]
TooManySet,
#[error("Unknown error while computing merkle proof")]
Expand Down
4 changes: 4 additions & 0 deletions utils/src/merkle_tree/full_merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ where
default_leaf: FrOf<Self::Hasher>,
_config: Self::Config,
) -> Result<Self, ZerokitMerkleTreeError> {
if depth >= usize::BITS as usize {
return Err(ZerokitMerkleTreeError::InvalidDepth);
}

// Compute cache node values, leaf to root
let mut cached_nodes: Vec<H::Fr> = Vec::with_capacity(depth + 1);
cached_nodes.push(default_leaf);
Expand Down
4 changes: 4 additions & 0 deletions utils/src/merkle_tree/optimal_merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ where
default_leaf: H::Fr,
_config: Self::Config,
) -> Result<Self, ZerokitMerkleTreeError> {
if depth >= usize::BITS as usize {
return Err(ZerokitMerkleTreeError::InvalidDepth);
}

// Compute cache node values, leaf to root
let mut cached_nodes: Vec<H::Fr> = Vec::with_capacity(depth + 1);
cached_nodes.push(default_leaf);
Expand Down
28 changes: 28 additions & 0 deletions utils/src/poseidon/poseidon_constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,31 @@ pub fn find_poseidon_ark_and_mds<F: PrimeField>(

(ark, mds)
}

#[cfg(test)]
mod test {
use ark_bn254::Fr;
use num_traits::Zero;

use super::*;

#[test]
fn test_find_poseidon_ark_and_mds_bn254_regression_no_inverse_panic() {
let result = std::panic::catch_unwind(|| {
// Parameters match the hardcoded BN254 Poseidon setup used by current tests.
find_poseidon_ark_and_mds::<Fr>(1, 0, 254, 2, 8, 56, 0)
});

assert!(
result.is_ok(),
"find_poseidon_ark_and_mds unexpectedly panicked (possible MDS inverse invariant break)"
);

let (ark, mds) = result.unwrap();
assert_eq!(ark.len(), (8 + 56) * 2);
assert_eq!(mds.len(), 2);
assert_eq!(mds[0].len(), 2);
assert_eq!(mds[1].len(), 2);
assert_ne!(mds[0][0], Fr::zero());
}
}
49 changes: 49 additions & 0 deletions utils/tests/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,55 @@ mod test {
assert_ne!(root_before, root_after);
}

#[test]
fn test_full_merkle_tree_new_depth_shift_overflow() {
let depth = usize::BITS as usize;
let result =
FullMerkleTree::<Keccak256>::new(depth, TestFr([0; 32]), FullMerkleConfig::default());
assert!(result.is_err());
}

#[test]
fn test_optimal_merkle_tree_new_depth_shift_overflow() {
let depth = usize::BITS as usize;
let result = OptimalMerkleTree::<Keccak256>::new(
depth,
TestFr([0; 32]),
OptimalMerkleConfig::default(),
);
assert!(result.is_err());
}

#[test]
fn test_full_merkle_tree_set_range_start_overflow() {
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
let result = tree_full.set_range(usize::MAX, std::iter::once(TestFr::from(1u32)));
assert!(result.is_err());
}

#[test]
fn test_optimal_merkle_tree_set_range_start_overflow() {
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
let result = tree_opt.set_range(usize::MAX, std::iter::once(TestFr::from(1u32)));
assert!(result.is_err());
}

#[test]
fn test_full_merkle_tree_override_range_min_index_underflow() {
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
let result =
tree_full.override_range(1, std::iter::once(TestFr::from(1u32)), [5usize].into_iter());
assert!(result.is_err());
}

#[test]
fn test_optimal_merkle_tree_override_range_min_index_underflow() {
let mut tree_opt = default_optimal_merkle_tree(DEFAULT_DEPTH);
let result =
tree_opt.override_range(1, std::iter::once(TestFr::from(1u32)), [5usize].into_iter());
assert!(result.is_err());
}

#[test]
fn test_update_next() {
let mut tree_full = default_full_merkle_tree(DEFAULT_DEPTH);
Expand Down
Loading