Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .cargo/config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@ test-r55 = "test --package r55"
test-e2e = "test --package r55 --test e2e"
test-erc20 = "test --package r55 --test erc20"
test-erc721 = "test --package r55 --test erc721"

test-univ2 = "test --package r55 --test uniswap-v2"

131 changes: 101 additions & 30 deletions contract-derive/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::error::Error;

use alloy_core::primitives::keccak256;
use alloy_dyn_abi::DynSolType;
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -130,26 +128,58 @@ where
let (mut_methods, immut_methods): (Vec<MethodInfo>, Vec<MethodInfo>) =
methods.into_iter().partition(|m| m.is_mutable());

// Generate implementations
let mut_method_impls = mut_methods
.iter()
.map(|method| generate_method_impl(method, interface_style, true));
let immut_method_impls = immut_methods
.iter()
.map(|method| generate_method_impl(method, interface_style, false));
// Generate implementations and documentation for mutable methods
let mut mut_method_impls = Vec::new();
let mut mut_method_docs = Vec::new();
for method in &mut_methods {
let (impl_code, doc_code) = generate_method_impl(method, interface_style, true);
mut_method_impls.push(impl_code);
mut_method_docs.push(doc_code);
}

// Generate implementations and documentation for immutable methods
let mut immut_method_impls = Vec::new();
let mut immut_method_docs = Vec::new();
for method in &immut_methods {
let (impl_code, doc_code) = generate_method_impl(method, interface_style, false);
immut_method_impls.push(impl_code);
immut_method_docs.push(doc_code);
}

quote! {
use core::marker::PhantomData;
/// `Interface` is a wrapper type for `Address`, which allows to easily interact with the contract's bytecode. Automatically derives the `CallCtx`, but needs to be initialized using a builder pattern.
/// ```
/// // Can perform a staticcall to an ERC20
/// pub fn call_static(&self, token_addr: Address) {
/// let token = IERC20::new(token_addr).with_ctx(self); // IERC20<ReadOnly>
/// }
///
/// // Can perform a (mutable) call to an ERC20
/// pub fn call_mutable(&mut self, token_addr: Address) {
/// let mut token = IERC20::new(token_addr).with_ctx(self); // IERC20<ReadWrite>
/// }
/// ```
///
/// -------------------------------------------------------------------------------------------
///
/// Implementation methods, always available:
/// * fn `address`() -> `Address`; Returns the address of the underlying contract
///
/// Immutable methods, available on `StaticCtx` and `MutableCtx`:
#(#immut_method_docs)*
///
/// Mutable methods, only available on `MutableCtx`:
#(#mut_method_docs)*
pub struct #interface_name<C: CallCtx> {
address: Address,
_ctx: PhantomData<C>
_ctx: core::marker::PhantomData<C>
}

impl InitInterface for #interface_name<ReadOnly> {
fn new(address: Address) -> InterfaceBuilder<Self> {
InterfaceBuilder {
address,
_phantom: PhantomData
_phantom: core::marker::PhantomData
}
}
}
Expand All @@ -159,7 +189,7 @@ where
fn into_interface(self) -> #interface_name<C> {
#interface_name {
address: self.address,
_ctx: PhantomData
_ctx: core::marker::PhantomData
}
}
}
Expand All @@ -170,7 +200,7 @@ where
fn from_builder(builder: InterfaceBuilder<Self>) -> Self {
Self {
address: builder.address,
_ctx: PhantomData
_ctx: core::marker::PhantomData
}
}
}
Expand All @@ -195,7 +225,7 @@ fn generate_method_impl(
method: &MethodInfo,
interface_style: Option<InterfaceNamingStyle>,
is_mutable: bool,
) -> TokenStream {
) -> (TokenStream, TokenStream) {
let name = method.name;
let return_type = method.return_type;
let method_selector = u32::from_be_bytes(
Expand Down Expand Up @@ -239,9 +269,38 @@ fn generate_method_impl(
quote! { &self},
)
};
let wrapper_type = extract_wrapper_types(return_type);

// Generate documentation
let (arg_names_for_docs, arg_types_for_docs) = get_arg_props_skip_first(method);
let args_docs = if arg_names_for_docs.is_empty() {
String::new()
} else {
let args_with_types = arg_names_for_docs.iter().zip(arg_types_for_docs.iter())
.map(|(name, ty)| format!("{}: `{}`", name, quote!(#ty).to_string().replace(" ", "").replace(",", ", ")))
.collect::<Vec<_>>()
.join(", ");
args_with_types
};
let doc_return_type = match &wrapper_type {
WrapperType::Result(ok_type, err_type) => quote!(Result<#ok_type, #err_type>),
WrapperType::Option(inner_type) => quote!(Option<#inner_type>),
WrapperType::None => match return_type {
ReturnType::Default => quote!{Option<()>},
ReturnType::Type(_, ty) => quote!(Option<#ty>),
},
};

let doc_line = format!(
r#"* fn `{}`({}) -> `{}`;"#,
method.name,
args_docs,
doc_return_type.to_string().replace(" ", "").replace(",", ", "),
);
let doc_stream = syn::parse_quote!(#[doc = #doc_line]);

// Generate different implementations based on return type
match extract_wrapper_types(&method.return_type) {
let impl_stream = match wrapper_type {
// If `Result<T, E>` handle each individual type
WrapperType::Result(ok_type, err_type) => quote! {
pub fn #name(#self_param, #(#arg_names: #arg_types),*) -> Result<#ok_type, #err_type> {
Expand Down Expand Up @@ -313,7 +372,9 @@ fn generate_method_impl(
}
}
}
}
};

(impl_stream, doc_stream)
}

pub enum WrapperType {
Expand Down Expand Up @@ -435,16 +496,17 @@ pub fn rust_type_to_sol_type(ty: &Type) -> Result<DynSolType, &'static str> {
.trim_start_matches('B')
.parse()
.map_err(|_| "Invalid fixed bytes size")?;
if size > 0 && size <= 32 {
Ok(DynSolType::FixedBytes(size))
if size > 0 && size <= 256 {
Ok(DynSolType::FixedBytes(size / 8))
} else {
Err("Invalid fixed bytes size (between 1-32)")
}
}
// Fixed-size unsigned integers
u if u.starts_with('U') => {
u if u.to_lowercase().starts_with('u') => {
let size: usize = u
.trim_start_matches('U')
.to_lowercase()
.trim_start_matches('u')
.parse()
.map_err(|_| "Invalid uint size")?;
if size > 0 && size <= 256 && size % 8 == 0 {
Expand All @@ -454,9 +516,10 @@ pub fn rust_type_to_sol_type(ty: &Type) -> Result<DynSolType, &'static str> {
}
}
// Fixed-size signed integers
i if i.starts_with('I') => {
i if i.to_lowercase().starts_with('i') => {
let size: usize = i
.trim_start_matches('I')
.to_lowercase()
.trim_start_matches('i')
.parse()
.map_err(|_| "Invalid int size")?;
if size > 0 && size <= 256 && size % 8 == 0 {
Expand Down Expand Up @@ -548,15 +611,23 @@ pub fn generate_deployment_code(
Some(method) => {
let method_info = MethodInfo::from(method);
let (arg_names, arg_types) = get_arg_props_all(&method_info);
quote! {
impl #struct_name { #method }

// Get encoded constructor args
let calldata = eth_riscv_runtime::msg_data();
if arg_types.is_empty() {
quote! {
impl #struct_name { #method }
#struct_name::new();
}
} else {
quote! {
impl #struct_name { #method }

let (#(#arg_names),*) = <(#(#arg_types),*)>::abi_decode(&calldata, true)
.expect("Failed to decode constructor args");
#struct_name::new(#(#arg_names),*);
// Get encoded constructor args
let calldata = eth_riscv_runtime::msg_data();

let (#(#arg_names),*) = <(#(#arg_types),*)>::abi_decode(&calldata, true)
.expect("Failed to decode constructor args");
#struct_name::new(#(#arg_names),*);
}
}
}
None => quote! {
Expand Down
17 changes: 13 additions & 4 deletions contract-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ pub fn event_derive(input: TokenStream) -> TokenStream {

#[proc_macro_attribute]
pub fn show_streams(attr: TokenStream, item: TokenStream) -> TokenStream {
println!("attr: \"{}\"", attr.to_string());
println!("item: \"{}\"", item.to_string());
println!("attr: \"{}\"", attr);
println!("item: \"{}\"", item);
item
}

Expand All @@ -276,6 +276,7 @@ pub fn contract(_attr: TokenStream, item: TokenStream) -> TokenStream {

let mut constructor = None;
let mut public_methods: Vec<&ImplItemMethod> = Vec::new();
let mut private_methods: Vec<&ImplItemMethod> = Vec::new();

// Iterate over the items in the impl block to find pub methods + constructor
for item in input.items.iter() {
Expand All @@ -284,10 +285,16 @@ pub fn contract(_attr: TokenStream, item: TokenStream) -> TokenStream {
constructor = Some(method);
} else if let syn::Visibility::Public(_) = method.vis {
public_methods.push(method);
} else {
private_methods.push(method);
}
}
}

let inner_methods: Vec<_> = private_methods
.iter()
.map(|method| quote! { #method })
.collect();
let input_methods: Vec<_> = public_methods
.iter()
.map(|method| quote! { #method })
Expand Down Expand Up @@ -464,6 +471,7 @@ pub fn contract(_attr: TokenStream, item: TokenStream) -> TokenStream {
#emit_helper

impl #struct_name { #(#input_methods)* }
impl #struct_name { #(#inner_methods)* }
impl Contract for #struct_name {
fn call(&mut self) {
self.call_with_data(&msg_data());
Expand Down Expand Up @@ -588,11 +596,12 @@ pub fn storage(_attr: TokenStream, input: TokenStream) -> TokenStream {
let expanded = quote! {
#vis struct #name { #(#struct_fields,)* }

impl #name {
pub fn default() -> Self {
impl Default for #name {
fn default() -> Self {
Self { #(#init_fields,)* }
}
}
impl #name { pub fn address(&self) -> Address { eth_riscv_runtime::this() } }
};

TokenStream::from(expanded)
Expand Down
12 changes: 12 additions & 0 deletions eth-riscv-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,18 @@ pub fn keccak256(offset: u64, size: u64) -> U256 {
U256::from_limbs([first, second, third, fourth])
}

pub fn this() -> Address {
let (first, second, third): (u64, u64, u64);
unsafe {
asm!("ecall", lateout("a0") first, lateout("a1") second, lateout("a2") third, in("t0") u8::from(Syscall::Address));
}
let mut bytes = [0u8; 20];
bytes[0..8].copy_from_slice(&first.to_be_bytes());
bytes[8..16].copy_from_slice(&second.to_be_bytes());
bytes[16..20].copy_from_slice(&third.to_be_bytes()[..4]);
Address::from_slice(&bytes)
}

pub fn msg_sender() -> Address {
let (first, second, third): (u64, u64, u64);
unsafe {
Expand Down
62 changes: 62 additions & 0 deletions eth-riscv-runtime/src/types/lock.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// TODO: use TLOAD/TSTORE rather than SLOAD/STORE once transient storage is implemented
use super::*;

/// A storage primitive that implements a reentrancy guard using the RAII pattern.
///
/// `Lock<E>` wraps a `Slot<bool>` to track lock state and uses a generic error type `E`
/// to provide type-safe error handling when attempting to acquire an already locked resource.
///
/// The `Lock` should be initialized in the contract constructor, by calling its `fn initialize()`.
///
/// The `LockGuard` returned by `fn acquire()` automatically releases the lock when it goes out of scope,
/// ensuring the lock is dropped even if the code returns early or panics.
#[derive(Default)]
pub struct Lock<E> {
unlocked: Slot<bool>,
_pd: PhantomData<E>,
}

impl<E> StorageLayout for Lock<E> {
fn allocate(first: u64, second: u64, third: u64, fourth: u64) -> Self {
Self {
unlocked: Slot::allocate(first, second, third, fourth),
_pd: PhantomData,
}
}
}

impl<E> Lock<E> {
/// Initialize a new lock in the unlocked state. Should only be created in the constructor.
pub fn initialize(&mut self) {
self.unlocked.write(true);
}

/// Attempts to acquire the lock, returning a guard that releases the lock when dropped.
/// When unable to acquire the lock, returns `locked_err`.
pub fn acquire(&mut self, locked_err: E) -> Result<LockGuard<E>, E> {
if !self.unlocked.read() {
return Err(locked_err);
}

self.unlocked.write(false);
Ok(LockGuard { slot_id: self.unlocked.id(), _pd: PhantomData })
}

/// Checks if the lock is currently unlocked
pub fn is_unlocked(&self) -> bool {
self.unlocked.read()
}
}

/// A guard that manages the locking state
pub struct LockGuard<E> {
slot_id: U256, // Store the key directly
_pd: PhantomData<E>,
}

impl<E> Drop for LockGuard<E> {
fn drop(&mut self) {
// Write `true` back directly using the stored key and the static __write method
<Slot<bool> as StorageStorable>::__write(self.slot_id, true);
}
}
Loading