diff --git a/Cargo.lock b/Cargo.lock index c806a0ef..e2b518b5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1172,18 +1172,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.51" +version = "1.0.52" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6" +checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" dependencies = [ "unicode-ident", ] [[package]] name = "quote" -version = "1.0.18" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1feb54ed693b93a84e14094943b84b7c4eae204c512b7ccb95ab0c66d278ad1" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -1552,6 +1552,7 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" name = "starknet" version = "0.2.0" dependencies = [ + "crypto-bigint", "serde_json", "starknet-accounts", "starknet-contract", @@ -1598,6 +1599,7 @@ version = "0.2.0" dependencies = [ "base64", "criterion", + "crypto-bigint", "ethereum-types", "flate2", "hex", @@ -1669,6 +1671,7 @@ dependencies = [ name = "starknet-macros" version = "0.1.0" dependencies = [ + "quote", "starknet-core", "syn", ] diff --git a/Cargo.toml b/Cargo.toml index 264259ea..43981588 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ starknet-contract = { version = "0.1.0", path = "./starknet-contract" } starknet-signers = { version = "0.1.0", path = "./starknet-signers" } starknet-accounts = { version = "0.1.0", path = "./starknet-accounts" } starknet-macros = { version = "0.1.0", path = "./starknet-macros" } +crypto-bigint = { version = "0.4.9", default-features = false } [dev-dependencies] serde_json = "1.0.74" diff --git a/examples/decoder_example.rs b/examples/decoder_example.rs new file mode 100644 index 00000000..86cd5451 --- /dev/null +++ b/examples/decoder_example.rs @@ -0,0 +1,74 @@ +use starknet::{ + core::types::FieldElement, + providers::{Provider, SequencerGatewayProvider}, +}; +use starknet_core::{ + decoder::{decode, Address, Decode, ParamType, Token}, + types::{TransactionType, ValueOutOfRangeError}, +}; + +use crypto_bigint::U256; +use starknet_macros::Decode; + +#[derive(Debug)] +struct Uint { + low: U256, + high: U256, +} + +#[derive(Debug, Decode)] +struct Transfer { + from: Address, + to: Address, + amount: Uint, +} + +impl TryFrom<&Token> for Uint { + type Error = ValueOutOfRangeError; + fn try_from(value: &Token) -> Result { + if let Token::Tuple(v) = value { + Ok(Uint { + low: U256::from(&v[0]), + high: U256::from(&v[1]), + }) + } else { + Err(ValueOutOfRangeError) + } + } +} + +impl TryFrom> for Transfer { + type Error = ValueOutOfRangeError; + + fn try_from(value: Vec) -> Result { + let from = Address::try_from(&value[0])?; + let to = Address::try_from(&value[1])?; + let amount = Uint::try_from(&value[2])?; + Ok(Self { from, to, amount }) + } +} + +#[tokio::main] +async fn main() { + let provider = SequencerGatewayProvider::starknet_alpha_goerli(); + + let tx_hash = "0x03a4ce1bb249ed3f8b190012dc7ca0ab2caff155d1b81727aebf2bb7ee12b04b"; + let tx_hash = FieldElement::from_hex_be(tx_hash).unwrap(); + let tx = provider.get_transaction(tx_hash).await.unwrap(); + println!("tx: {tx:?}"); + + let tx_type = tx.r#type.unwrap(); + + let types = [ + ParamType::FieldElement, + ParamType::FieldElement, + ParamType::Tuple(2), + ]; + if let TransactionType::L1Handler(tx) = tx_type { + let decoded = decode(&types, &tx.calldata).unwrap(); + println!("decoded: {decoded:?}"); + + let transfer = Transfer::decode(&decoded); + println!("transfer: {transfer:?}"); + } +} diff --git a/starknet-core/Cargo.toml b/starknet-core/Cargo.toml index 11b6a12a..ae17ae32 100644 --- a/starknet-core/Cargo.toml +++ b/starknet-core/Cargo.toml @@ -29,6 +29,7 @@ serde_json = { version = "1.0.74", features = ["arbitrary_precision"] } serde_with = "2.2.0" sha3 = "0.10.0" thiserror = "1.0.30" +crypto-bigint = { version = "0.4.9", default-features = false } [dev-dependencies] criterion = { version = "0.4.0", default-features = false } diff --git a/starknet-core/src/decoder.rs b/starknet-core/src/decoder.rs new file mode 100644 index 00000000..199b2f14 --- /dev/null +++ b/starknet-core/src/decoder.rs @@ -0,0 +1,441 @@ +use crate::types::FieldElement; + +pub enum ParamType { + FieldElement, + Array, + Tuple(usize), +} + +#[derive(PartialEq, Eq, Debug)] +pub enum Token { + FieldElement(FieldElement), + Array(Vec), + Tuple(Vec), +} + +#[derive(PartialEq, Eq, Debug)] +pub struct DecodeResult { + token: Token, + new_offset: usize, +} + +mod decoder_error { + + #[derive(Debug, PartialEq)] + pub enum DecoderError { + InvalidLength, + ValueOutOfRange, + } + + #[cfg(feature = "std")] + impl std::error::Error for DecoderError {} + + impl core::fmt::Display for DecoderError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::InvalidLength => write!(f, "invalid length"), + Self::ValueOutOfRange => write!(f, "number out of range"), + } + } + } +} +pub use decoder_error::DecoderError; +use starknet_ff::ValueOutOfRangeError; + +fn validate_data_length( + data: &[FieldElement], + offset: usize, + len: usize, +) -> Result<(), DecoderError> { + if offset + len > data.len() { + Err(DecoderError::InvalidLength) + } else { + Ok(()) + } +} + +fn decode_param( + param: &ParamType, + data: &[FieldElement], + offset: usize, + validate: bool, +) -> Result { + match *param { + ParamType::FieldElement => { + if validate { + validate_data_length(data, offset, 1)?; + } + + Ok(DecodeResult { + token: Token::FieldElement(data[offset]), + new_offset: offset + 1, + }) + } + ParamType::Array => { + if validate { + validate_data_length(data, offset, 1)?; + } + + let size: usize = + u32::try_from(data[offset]).map_err(|_| DecoderError::ValueOutOfRange)? as usize; + + if validate { + validate_data_length(data, offset, size + 1)?; + } + + Ok(DecodeResult { + token: Token::Array(data[(offset + 1)..(offset + size + 1)].to_vec()), + new_offset: offset + size + 1, + }) + } + ParamType::Tuple(size) => { + if validate { + validate_data_length(data, offset, size)?; + } + + Ok(DecodeResult { + token: Token::Tuple(data[offset..(offset + size)].to_vec()), + new_offset: offset + size, + }) + } + } +} + +fn decode_impl( + types: &[ParamType], + data: &[FieldElement], + offset: usize, + validate: bool, +) -> Result, DecoderError> { + let mut tokens = vec![]; + let mut offset = offset; + + for param in types { + let res = decode_param(param, data, offset, validate)?; + offset = res.new_offset; + tokens.push(res.token); + } + + Ok(tokens) +} + +pub fn decode(types: &[ParamType], data: &[FieldElement]) -> Result, DecoderError> { + decode_impl(types, data, 0, true) +} + +impl TryFrom<&Token> for u32 { + // TODO: add an error type to represent invalid token types like Array and Tuple ! + type Error = ValueOutOfRangeError; + fn try_from(value: &Token) -> Result { + match value { + Token::FieldElement(felt) => u32::try_from(*felt), + _ => Err(ValueOutOfRangeError), + } + } +} + +#[derive(Debug)] +pub struct Address(String); + +impl TryFrom<&Token> for Address { + // TODO: add an error type to represent invalid token types like Array and Tuple ! + type Error = ValueOutOfRangeError; + fn try_from(value: &Token) -> Result { + match value { + Token::FieldElement(felt) => Ok(Address(format!("{:#064x}", felt))), + _ => Err(ValueOutOfRangeError), + } + } +} + +impl TryFrom<&Token> for String { + // TODO: add an error type to represent invalid token types like Array and Tuple ! + type Error = ValueOutOfRangeError; + fn try_from(value: &Token) -> Result { + match value { + Token::FieldElement(felt) => String::try_from(*felt), + _ => Err(ValueOutOfRangeError), + } + } +} + +impl TryFrom for Vec { + type Error = ValueOutOfRangeError; + fn try_from(value: Token) -> Result { + match value { + Token::Array(v) => v.iter().map(|felt| u32::try_from(*felt)).collect(), + _ => Err(ValueOutOfRangeError), + } + } +} + +pub trait Decode { + fn decode(tokens: &[Token]) -> Self; +} + +#[cfg(test)] +mod test { + use starknet_crypto::FieldElement; + + use super::{decode, decode_impl, decode_param, DecodeResult, DecoderError, ParamType, Token}; + + #[test] + fn decode_param_field_element() -> Result<(), DecoderError> { + let result = decode_param(&ParamType::FieldElement, &[FieldElement::ONE], 0, true)?; + let expected_result = DecodeResult { + token: Token::FieldElement(FieldElement::ONE), + new_offset: 1, + }; + assert_eq!(result, expected_result); + Ok(()) + } + + #[test] + fn decode_param_field_element_empty_data() { + let result = decode_param(&ParamType::FieldElement, &[], 0, true); + let expected_result = Err(DecoderError::InvalidLength); + assert_eq!(result, expected_result) + } + + #[test] + #[should_panic(expected = "index out of bounds")] + fn decode_param_field_element_empty_data_no_validate() { + decode_param(&ParamType::FieldElement, &[], 0, false).unwrap(); + } + + #[test] + fn decode_param_array() -> Result<(), DecoderError> { + let result = decode_param( + &ParamType::Array, + &[FieldElement::TWO, FieldElement::ONE, FieldElement::THREE], + 0, + true, + )?; + let expected_result = DecodeResult { + token: Token::Array(vec![FieldElement::ONE, FieldElement::THREE]), + new_offset: 3, + }; + assert_eq!(result, expected_result); + Ok(()) + } + + #[test] + fn decode_param_array_empty_data() { + let result = decode_param(&ParamType::Array, &[], 0, true); + let expected_result = Err(DecoderError::InvalidLength); + assert_eq!(result, expected_result) + } + + #[test] + #[should_panic(expected = "index out of bounds")] + fn decode_param_array_empty_data_no_validate() { + decode_param(&ParamType::Array, &[], 0, false).unwrap(); + } + + #[test] + fn decode_param_array_insufficient_data() { + let result = decode_param( + &ParamType::Array, + &[FieldElement::TWO, FieldElement::THREE], + 0, + true, + ); + + let expected_result = Err(DecoderError::InvalidLength); + assert_eq!(result, expected_result) + } + + #[test] + #[should_panic(expected = "range end index 3 out of range for slice of length 2")] + fn decode_param_array_insufficient_data_no_validate() { + decode_param( + &ParamType::Array, + &[FieldElement::TWO, FieldElement::THREE], + 0, + false, + ) + .unwrap(); + } + + #[test] + fn decode_param_array_invalid_size() { + let result = decode_param( + &ParamType::Array, + &[FieldElement::MAX, FieldElement::ONE, FieldElement::THREE], + 0, + true, + ); + let expected_result = Err(DecoderError::ValueOutOfRange); + assert_eq!(result, expected_result); + } + + #[test] + #[should_panic(expected = "ValueOutOfRange")] + fn decode_param_array_invalid_size_no_validate() { + decode_param( + &ParamType::Array, + &[FieldElement::MAX, FieldElement::ONE, FieldElement::THREE], + 0, + true, + ) + .unwrap(); + } + + #[test] + fn decode_param_tuple() -> Result<(), DecoderError> { + let result = decode_param( + &ParamType::Tuple(2), + &[FieldElement::TWO, FieldElement::THREE], + 0, + true, + )?; + let expected_result = DecodeResult { + token: Token::Tuple(vec![FieldElement::TWO, FieldElement::THREE]), + new_offset: 2, + }; + assert_eq!(result, expected_result); + Ok(()) + } + + #[test] + fn decode_param_tuple_empty_data() { + let result = decode_param(&ParamType::Tuple(2), &[], 0, true); + let expected_result = Err(DecoderError::InvalidLength); + assert_eq!(result, expected_result); + } + + #[test] + #[should_panic(expected = "range end index 2 out of range for slice of length 0")] + fn decode_param_tuple_empty_data_no_validate() { + decode_param(&ParamType::Tuple(2), &[], 0, false).unwrap(); + } + + #[test] + fn decode_param_tuple_insufficient_data() { + let result = decode_param( + &ParamType::Tuple(3), + &[FieldElement::TWO, FieldElement::THREE, FieldElement::ONE], + 1, + true, + ); + let expected_result = Err(DecoderError::InvalidLength); + assert_eq!(result, expected_result); + } + + #[test] + #[should_panic(expected = "range end index 4 out of range for slice of length 3")] + fn decode_param_tuple_insufficient_data_no_validate() { + decode_param( + &ParamType::Tuple(3), + &[FieldElement::TWO, FieldElement::THREE, FieldElement::ONE], + 1, + false, + ) + .unwrap(); + } + + #[test] + fn decode_data() -> Result<(), DecoderError> { + let types = [ + ParamType::FieldElement, + ParamType::Array, + ParamType::Tuple(2), + ]; + + let data = [ + FieldElement::ONE, // field element + FieldElement::TWO, // array length + FieldElement::ONE, // first element of the array + FieldElement::THREE, // second element of the array + FieldElement::TWO, // first element of the tuple + FieldElement::THREE, // second element of the tuple + ]; + + let expected_result = vec![ + Token::FieldElement(FieldElement::ONE), + Token::Array(vec![FieldElement::ONE, FieldElement::THREE]), + Token::Tuple(vec![FieldElement::TWO, FieldElement::THREE]), + ]; + + let result = decode(&types, &data)?; + + assert_eq!(result, expected_result); + Ok(()) + } + + #[test] + fn decode_data_exceeds_types() -> Result<(), DecoderError> { + let types = [ParamType::FieldElement, ParamType::Tuple(3)]; + + let data = [ + FieldElement::ONE, // field element + FieldElement::TWO, // first element of the tuple + FieldElement::ONE, // second element of the tuple + FieldElement::THREE, // third element of the tuple + FieldElement::TWO, // exceeded + FieldElement::THREE, // exceeded + ]; + + let expected_result = vec![ + Token::FieldElement(FieldElement::ONE), + Token::Tuple(vec![ + FieldElement::TWO, + FieldElement::ONE, + FieldElement::THREE, + ]), + ]; + + let result = decode(&types, &data)?; + + assert_eq!(result, expected_result); + Ok(()) + } + + #[test] + fn decode_missing_data() { + let types = [ + ParamType::FieldElement, + ParamType::Array, + ParamType::Tuple(2), + ParamType::FieldElement, + ]; + + let data = [ + FieldElement::ONE, // field element + FieldElement::TWO, // array length + FieldElement::ONE, // first element of the array + FieldElement::THREE, // second element of the array + FieldElement::TWO, // first element of the tuple + FieldElement::THREE, // second element of the tuple + // missing last field element + ]; + + let result = decode(&types, &data); + let expected_result = Err(DecoderError::InvalidLength); + + assert_eq!(result, expected_result); + } + + #[test] + #[should_panic(expected = "index out of bounds: the len is 6 but the index is 6")] + fn decode_missing_data_no_validate() { + let types = [ + ParamType::FieldElement, + ParamType::Array, + ParamType::Tuple(2), + ParamType::FieldElement, + ]; + + let data = [ + FieldElement::ONE, // field element + FieldElement::TWO, // array length + FieldElement::ONE, // first element of the array + FieldElement::THREE, // second element of the array + FieldElement::TWO, // first element of the tuple + FieldElement::THREE, // second element of the tuple + // missing last field element + ]; + + decode_impl(&types, &data, 0, false).unwrap(); + } +} diff --git a/starknet-core/src/lib.rs b/starknet-core/src/lib.rs index 3c715981..b0e881a0 100644 --- a/starknet-core/src/lib.rs +++ b/starknet-core/src/lib.rs @@ -10,3 +10,5 @@ pub mod crypto; pub mod utils; pub mod chain_id; + +pub mod decoder; diff --git a/starknet-ff/src/lib.rs b/starknet-ff/src/lib.rs index 6f15693f..e8cb0d25 100644 --- a/starknet-ff/src/lib.rs +++ b/starknet-ff/src/lib.rs @@ -13,7 +13,7 @@ use core::{ use fr::FrParameters; use ark_ff::{fields::Fp256, BigInteger, BigInteger256, Field, PrimeField, SquareRootField}; -use crypto_bigint::{CheckedAdd, CheckedMul, Zero, U256}; +use crypto_bigint::{CheckedAdd, CheckedMul, Zero, U256, U512}; mod fr; @@ -616,6 +616,30 @@ impl TryFrom for u8 { } } +impl TryFrom for bool { + type Error = ValueOutOfRangeError; + fn try_from(value: FieldElement) -> Result { + if value == FieldElement::ONE { + return Ok(true); + } else if value == FieldElement::ZERO { + return Ok(false); + } else { + return Err(ValueOutOfRangeError); + } + } +} + +impl TryFrom for String { + type Error = ValueOutOfRangeError; + fn try_from(value: FieldElement) -> Result { + let be = value.to_bytes_be(); + String::from_utf8(be.to_vec()) + .map_err(|_| ValueOutOfRangeError) + .map(|s| s.trim_start_matches("\0").to_owned()) + .map(String::from) + } +} + impl TryFrom for u16 { type Error = ValueOutOfRangeError; @@ -946,4 +970,15 @@ mod tests { ); } } + + #[test] + #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)] + fn test_string() { + let strs = ["Hello world", "1924", "0x123", "0"]; + + for str in strs.into_iter() { + let felt = FieldElement::from_byte_slice_be(str.as_bytes()).unwrap(); + assert_eq!(String::try_from(felt).unwrap(), str) + } + } } diff --git a/starknet-macros/Cargo.toml b/starknet-macros/Cargo.toml index 4abd7a4b..1a334e6b 100644 --- a/starknet-macros/Cargo.toml +++ b/starknet-macros/Cargo.toml @@ -16,6 +16,7 @@ keywords = ["ethereum", "starknet", "web3"] proc-macro = true [dependencies] +quote = "1.0.26" starknet-core = { version = "0.2.0", path = "../starknet-core" } syn = "1.0.96" diff --git a/starknet-macros/src/lib.rs b/starknet-macros/src/lib.rs index 1d887834..309ce208 100644 --- a/starknet-macros/src/lib.rs +++ b/starknet-macros/src/lib.rs @@ -1,9 +1,12 @@ use proc_macro::TokenStream; +use quote::quote; use starknet_core::{ types::FieldElement, utils::{cairo_short_string_to_felt, get_selector_from_name}, }; -use syn::{parse_macro_input, LitStr}; +use syn::{ + parse_macro_input, punctuated::Punctuated, token::Comma, Data, DataStruct, Field, Ident, LitStr, +}; #[proc_macro] pub fn selector(input: TokenStream) -> TokenStream { @@ -124,3 +127,61 @@ fn field_element_path() -> &'static str { fn field_element_path() -> &'static str { "::starknet::core::types::FieldElement" } + +#[proc_macro_derive(Decode)] +pub fn decode_macro_derive(input: TokenStream) -> TokenStream { + // Construct a representation of Rust code as a syntax tree + // that we can manipulate + let ast = syn::parse(input).unwrap(); + + // Build the trait implementation + impl_decode_macro(&ast) +} + +fn impl_decode_macro(ast: &syn::DeriveInput) -> TokenStream { + if let Data::Struct(DataStruct { + fields: syn::Fields::Named(ref fields), + .. + }) = ast.data + { + impl_decode_macro_for_struct(&ast.ident, &fields.named) + } else { + // TODO: use abort_call_site instead + panic!("Decode only supports non-tuple structs") + } +} + +fn impl_decode_macro_for_struct(name: &Ident, fields: &Punctuated) -> TokenStream { + // Generate Decode implementation for a struct given its name and fields + // For a struct: + // struct Transfer { + // from: Address, + // to: Address, + // amount: Uint, + // } + // Generates: + // impl Decode for Transfer { + // fn decode(tokens: &[Token]) -> Self { + // Transfer { + // from: Address::try_from(&tokens[0]).unwrap(), + // to: Address::try_from(&tokens[1]).unwrap(), + // amount: Uint::try_from(&tokens[2]).unwrap(), + // } + // } + // } + + let fields = fields.iter().enumerate().map(|(i, field)| { + let field_name = field.ident.as_ref().unwrap(); + let field_type = &field.ty; + quote!( #field_name: #field_type::try_from(&tokens[#i]).unwrap() ) + }); + + let gen = quote! { + impl Decode for #name { + fn decode(tokens: &[Token]) -> Self { + #name { #( #fields ),* } + } + } + }; + gen.into() +}