diff --git a/Cargo.toml b/Cargo.toml index 4099d4a..34e927c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ ethereum_serde_utils = "0.8.0" ethereum_ssz = "0.9.0" serde = "1.0.0" serde_derive = "1.0.0" -typenum = "1.12.0" +typenum = { version = "1.12.0", features = ["const-generics"] } smallvec = "1.8.0" arbitrary = { version = "1.0", features = ["derive"], optional = true } itertools = "0.13.0" @@ -24,3 +24,7 @@ itertools = "0.13.0" [dev-dependencies] serde_json = "1.0.0" tree_hash_derive = "0.10.0" + +[patch.crates-io] +tree_hash = { path = "../tree_hash/tree_hash" } +tree_hash_derive = { path = "../tree_hash/tree_hash_derive" } diff --git a/src/fixed_vector.rs b/src/fixed_vector.rs index 53bf8ad..66cf75f 100644 --- a/src/fixed_vector.rs +++ b/src/fixed_vector.rs @@ -210,6 +210,14 @@ where } } +impl tree_hash::prototype::MerkleProof for FixedVector +where + T: tree_hash::TreeHash +{ + fn compute_proof_for_gindex(&self, gindex: usize) -> Result, tree_hash::prototype::Error>{ + crate::tree_hash::generate_proof_for_vec::(&self.vec, gindex) + } +} impl ssz::Encode for FixedVector where T: ssz::Encode, @@ -569,4 +577,74 @@ mod test { let result: Result, _> = serde_json::from_value(json); assert!(result.is_ok()); } + + #[test] + fn merkle_proof_basic() { + use tree_hash::prototype::MerkleProof; + use typenum::U4; + + let vec: FixedVector = FixedVector::new(vec![1, 2, 3, 4]).unwrap(); + + let proof = vec.compute_proof_for_gindex(1); + assert!(proof.is_ok()); + if let Ok(proof) = proof { + assert_eq!(proof.len(), 0); + } + + let proof = vec.compute_proof_for_gindex(2); + assert!(proof.is_ok()); + if let Ok(proof) = proof { + assert!(!proof.is_empty()); + } + + let proof = vec.compute_proof_for_gindex(0); + assert!(proof.is_err()); + } + + #[test] + fn merkle_proof_complex_types() { + use tree_hash::prototype::MerkleProof; + use typenum::U2; + + let a1 = A { a: 1, b: 2 }; + let a2 = A { a: 3, b: 4 }; + let vec: FixedVector = FixedVector::new(vec![a1, a2]).unwrap(); + + let proof = vec.compute_proof_for_gindex(2); + assert!(proof.is_ok()); + if let Ok(proof) = proof { + assert!(!proof.is_empty()); + + for hash in proof { + assert_eq!(hash.len(), 32); + } + } + } + + #[test] + fn merkle_proof_tree_depth() { + use tree_hash::prototype::MerkleProof; + use typenum::U8; + + let vec: FixedVector = FixedVector::new(vec![1, 2, 3, 4, 5, 6, 7, 8]).unwrap(); + + let gindices = vec![1, 2, 3, 4, 5, 6, 7, 8, 15, 16]; + + for gindex in gindices { + let proof = vec.compute_proof_for_gindex(gindex); + + if gindex == 0 { + assert!(proof.is_err()); + } else if gindex == 1 { + if let Ok(proof) = proof { + assert_eq!(proof.len(), 0); + } + } else { + if let Ok(proof) = proof { + let expected_depth = 64 - gindex.leading_zeros() as usize - 1; + assert_eq!(proof.len(), expected_depth); + } + } + } + } } diff --git a/src/tree_hash.rs b/src/tree_hash.rs index 675d0ea..c37637d 100644 --- a/src/tree_hash.rs +++ b/src/tree_hash.rs @@ -1,5 +1,126 @@ +use crate::VariableList; +use tree_hash::prototype::{get_vector_item_position, vector_chunk_count, Resolve, VecIndex}; use tree_hash::{Hash256, MerkleHasher, TreeHash, TreeHashType}; -use typenum::Unsigned; +use typenum::{ + generic_const_mappings::{Const, ToUInt, U}, + Unsigned, +}; + +pub fn generate_proof_for_vec(vec: &[T], gindex: usize) -> Result, tree_hash::prototype::Error> +where + T: TreeHash, + N: Unsigned, +{ + let target_size = N::to_usize(); + + if gindex == 0 { + return Err(tree_hash::prototype::Error::Oops); + } + + if target_size == 0 { + return Ok(vec![]); + } + + generate_proof::(vec, gindex) +} + +fn generate_proof(vec: &[T], gindex: usize) -> Result, tree_hash::prototype::Error> +where + T: TreeHash, + N: Unsigned, +{ + let target_size = N::to_usize(); + let mut proof = Vec::new(); + let mut current_gindex = gindex; + + let (_, effective_size) = match T::tree_hash_type() { + TreeHashType::Basic => { + let chunk_count = (target_size + T::tree_hash_packing_factor() - 1) / T::tree_hash_packing_factor(); + let padded_count = chunk_count.next_power_of_two(); + (64 - padded_count.leading_zeros() as usize, padded_count) + } + _ => { + let padded_size = target_size.next_power_of_two(); + (64 - padded_size.leading_zeros() as usize, padded_size) + } + }; + + while current_gindex > 1 { + let is_right_child = current_gindex % 2 == 1; + let sibling_gindex = if is_right_child { + current_gindex - 1 + } else { + current_gindex + 1 + }; + + let sibling_hash = compute_node_hash_at_gindex::(vec, sibling_gindex, effective_size)?; + proof.push(sibling_hash); + + current_gindex /= 2; + } + + Ok(proof) +} + +fn compute_node_hash_at_gindex( + vec: &[T], + gindex: usize, + effective_size: usize +) -> Result +where + T: TreeHash, + N: Unsigned, +{ + let target_size = N::to_usize(); + + match T::tree_hash_type() { + TreeHashType::Basic => { + let chunk_count = (target_size + T::tree_hash_packing_factor() - 1) / T::tree_hash_packing_factor(); + + if gindex >= effective_size { + let chunk_index = gindex - effective_size; + if chunk_index < chunk_count { + let start_idx = chunk_index * T::tree_hash_packing_factor(); + let end_idx = std::cmp::min(start_idx + T::tree_hash_packing_factor(), vec.len()); + + let mut hasher = MerkleHasher::with_leaves(1); + for j in start_idx..end_idx { + hasher.write(&vec[j].tree_hash_packed_encoding()).map_err(|_| tree_hash::prototype::Error::Oops)?; + } + hasher.finish().map_err(|_| tree_hash::prototype::Error::Oops) + } else { + Ok(Hash256::new([0; 32])) + } + } else { + let left_child = gindex * 2; + let right_child = gindex * 2 + 1; + + let left_hash = compute_node_hash_at_gindex::(vec, left_child, effective_size)?; + let right_hash = compute_node_hash_at_gindex::(vec, right_child, effective_size)?; + + Ok(tree_hash::merkle_root(&[left_hash.as_slice(), right_hash.as_slice()].concat(), 0)) + } + } + _ => { + if gindex >= effective_size { + let leaf_index = gindex - effective_size; + if leaf_index < vec.len() { + Ok(vec[leaf_index].tree_hash_root()) + } else { + Ok(Hash256::new([0; 32])) + } + } else { + let left_child = gindex * 2; + let right_child = gindex * 2 + 1; + + let left_hash = compute_node_hash_at_gindex::(vec, left_child, effective_size)?; + let right_hash = compute_node_hash_at_gindex::(vec, right_child, effective_size)?; + + Ok(tree_hash::merkle_root(&[left_hash.as_slice(), right_hash.as_slice()].concat(), 0)) + } + } + } +} /// A helper function providing common functionality between the `TreeHash` implementations for /// `FixedVector` and `VariableList`. @@ -39,3 +160,181 @@ where } } } + +impl Resolve> for VariableList> +where + T: TreeHash, + Const: ToUInt, +{ + type Output = T; + + fn gindex(parent_index: usize) -> usize { + // Base index is 2 due to length mixin. + let base_index = 2; + + // Chunk count takes into account packing of leaves. + let chunk_count = vector_chunk_count::(N); + + let pos = get_vector_item_position::(I); + + // Gindex of Nth element of this vector. + parent_index * base_index * chunk_count.next_power_of_two() + pos + } +} + +#[cfg(test)] +mod test { + use super::*; + use tree_hash::prototype::{Field, Path, Resolve, VecIndex}; + use tree_hash_derive::TreeHash; + use typenum::{U10, U5}; + + // Some example structs. + #[derive(TreeHash)] + struct Nested3 { + x3: Nested2, + y3: Nested1, + } + + #[derive(TreeHash)] + struct Nested2 { + x2: Nested1, + y2: Nested1, + } + + #[derive(TreeHash)] + struct Nested1 { + x1: u64, + y1: VariableList, + } + + // Fields of Nested3 (these would be generated). + struct FieldX3; + struct FieldY3; + + impl Field for FieldX3 { + const NUM_FIELDS: usize = 2; + const INDEX: usize = 0; + } + + impl Field for FieldY3 { + const NUM_FIELDS: usize = 2; + const INDEX: usize = 1; + } + + // Fields of Nested2 (generated). + struct FieldX2; + struct FieldY2; + + impl Field for FieldX2 { + const NUM_FIELDS: usize = 2; + const INDEX: usize = 0; + } + + impl Field for FieldY2 { + const NUM_FIELDS: usize = 2; + const INDEX: usize = 1; + } + + // Fields of Nested1 (generated). + struct FieldX1; + struct FieldY1; + + impl Field for FieldX1 { + const NUM_FIELDS: usize = 2; + const INDEX: usize = 0; + } + + impl Field for FieldY1 { + const NUM_FIELDS: usize = 2; + const INDEX: usize = 1; + } + + // Implementations of Resolve (generated). + impl Resolve for Nested3 { + type Output = Nested2; + + fn gindex(parent_index: usize) -> usize { + parent_index * ::NUM_FIELDS.next_power_of_two() + + ::INDEX + } + } + + impl Resolve for Nested3 { + type Output = Nested1; + + fn gindex(parent_index: usize) -> usize { + parent_index * ::NUM_FIELDS.next_power_of_two() + + ::INDEX + } + } + + impl Resolve for Nested2 { + type Output = Nested1; + + fn gindex(parent_index: usize) -> usize { + parent_index * ::NUM_FIELDS.next_power_of_two() + + ::INDEX + } + } + + impl Resolve for Nested2 { + type Output = Nested1; + + fn gindex(parent_index: usize) -> usize { + parent_index * ::NUM_FIELDS.next_power_of_two() + + ::INDEX + } + } + + impl Resolve for Nested1 { + type Output = u64; + + fn gindex(parent_index: usize) -> usize { + parent_index * ::NUM_FIELDS.next_power_of_two() + + ::INDEX + } + } + + impl Resolve for Nested1 { + type Output = VariableList; + + fn gindex(parent_index: usize) -> usize { + parent_index * ::NUM_FIELDS.next_power_of_two() + + ::INDEX + } + } + + // x3.x2.x1 + type FieldX3X2X1 = Path>; + + // x3.x2.x1 + type FieldX3X2Y1 = Path>; + + // x3.y2.y1.5 + type FieldX3Y2Y1I5 = Path>>>; + + // 0.x3.y2.y1.5 + type FieldI0X3Y2Y1I5 = + Path, Path>>>>; + + // This evaluates to u64 at compile-time. + type TypeOfFieldX3X2X1 = >::Output; + + #[test] + fn gindex_basics() { + // This works but just shows compile-time field resolution. + let x: TypeOfFieldX3X2X1 = 0u64; + + // Gindex computation. + assert_eq!(>::gindex(1), 8); + assert_eq!(>::gindex(1), 9); + + // FIXME: Not sure if these values are correct + assert_eq!(>::gindex(1), 89); + assert_eq!( + as Resolve>::gindex(1), + 1049 + ); + } +} diff --git a/src/variable_list.rs b/src/variable_list.rs index 5dffad1..db97f26 100644 --- a/src/variable_list.rs +++ b/src/variable_list.rs @@ -228,6 +228,27 @@ where } } +impl tree_hash::prototype::MerkleProof for VariableList +where + T: tree_hash::TreeHash +{ + fn compute_proof_for_gindex(&self, gindex: usize) -> Result, tree_hash::prototype::Error>{ + if gindex < 2 { + return Err(tree_hash::prototype::Error::Oops); + } + + let adjusted_gindex = if gindex == 2 { + 1 + } else if gindex > 2 { + gindex - 2 + } else { + return Err(tree_hash::prototype::Error::Oops); + }; + + crate::tree_hash::generate_proof_for_vec::(&self.vec, adjusted_gindex) + } +} + impl ssz::Encode for VariableList where T: ssz::Encode, @@ -616,4 +637,97 @@ mod test { let result: Result, _> = serde_json::from_value(json); assert!(result.is_ok()); } + + #[test] + fn merkle_proof_basic() { + use tree_hash::prototype::MerkleProof; + use typenum::U4; + + let list: VariableList = VariableList::new(vec![1, 2, 3]).unwrap(); + + let proof = list.compute_proof_for_gindex(0); + assert!(proof.is_err()); + + let proof = list.compute_proof_for_gindex(1); + assert!(proof.is_err()); + + let proof = list.compute_proof_for_gindex(2); + assert!(proof.is_ok()); + if let Ok(proof) = proof { + assert_eq!(proof.len(), 0); + } + + let proof = list.compute_proof_for_gindex(4); + assert!(proof.is_ok()); + if let Ok(proof) = proof { + assert!(!proof.is_empty()); + } + } + + + #[test] + fn merkle_proof_complex_types() { + use tree_hash::prototype::MerkleProof; + use typenum::U4; + + // Create a list of composite types + let a1 = A { a: 1, b: 2 }; + let a2 = A { a: 3, b: 4 }; + let a3 = A { a: 5, b: 6 }; + let list: VariableList = VariableList::new(vec![a1, a2, a3]).unwrap(); + + // Test proof generation for complex types + let proof = list.compute_proof_for_gindex(4); + assert!(proof.is_ok()); + if let Ok(proof) = proof { + // Verify proof structure + assert!(!proof.is_empty()); + + // Test that all proof elements are valid Hash256 + for hash in proof { + assert_eq!(hash.len(), 32); + } + } + } + + + + #[test] + fn merkle_proof_tree_structure() { + use tree_hash::prototype::MerkleProof; + use typenum::U8; + + let list: VariableList = VariableList::new(vec![1, 2, 3, 4, 5]).unwrap(); + + let test_gindices = vec![2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16]; + + for gindex in test_gindices { + let proof = list.compute_proof_for_gindex(gindex); + assert!(proof.is_ok(), "Failed to generate proof for gindex {}", gindex); + + if let Ok(proof) = proof { + let adjusted_gindex = if gindex == 2 { + 1 + } else { + gindex - 2 + }; + + if adjusted_gindex == 1 { + assert_eq!(proof.len(), 0, "Proof should be empty for gindex {} (maps to root)", gindex); + } else { + assert!(!proof.is_empty(), "Proof should not be empty for gindex {}", gindex); + } + + let adjusted_gindex = if gindex == 2 { + 1 + } else { + gindex - 2 + }; + if adjusted_gindex > 1 { + let expected_depth = 64 - adjusted_gindex.leading_zeros() as usize - 1; + assert_eq!(proof.len(), expected_depth, "Incorrect proof length for gindex {}", gindex); + } + } + } + } }