Skip to content

Commit 6849818

Browse files
committed
Added option for raw_input attribute on entrypoints.
Additionally added a proxy-style example contract.
1 parent cea2157 commit 6849818

30 files changed

+2251
-110
lines changed

crates/cairo-lang-starknet-classes/src/casm_contract_class_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ fn test_casm_contract_from_contract_class_failure(name: &str) {
4646
#[test_case("with_erc20__erc20_contract")]
4747
#[test_case("with_ownable__ownable_balance")]
4848
#[test_case("ownable_erc20__ownable_erc20_contract")]
49+
#[test_case("proxy__proxy_contract")]
4950
#[test_case("upgradable_counter__counter_contract")]
5051
#[test_case("mintable__mintable_erc20_ownable")]
5152
#[test_case("multi_component__contract_with_4_components")]

crates/cairo-lang-starknet-classes/src/compiled_class_hash_test_data/contracts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,19 @@ multi_component__contract_with_4_components
235235

236236
//! > legacy_compiled_class_hash
237237
671a3f37217861f0b0033f92e8f32969c5487bb583522494e4d2616386e155a
238+
239+
//! > ==========================================================================
240+
241+
//! > Proxy contract
242+
243+
//! > test_runner_name
244+
test_compiled_class_hash
245+
246+
//! > compiled_class
247+
proxy__proxy_contract
248+
249+
//! > compiled_class_hash
250+
7090b8eab257dfac2afe78ae324223c695b0dfbb8ac5e236a2ae975dba7cefc
251+
252+
//! > legacy_compiled_class_hash
253+
7bf7f6ee4ddabf49a20b7c3c62f1428d0433cda950e74c39821f9e0d9e9fcdc

crates/cairo-lang-starknet-classes/src/contract_class_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ fn test_serialization() {
6161
#[test_case("with_erc20__erc20_contract")]
6262
#[test_case("with_ownable__ownable_balance")]
6363
#[test_case("ownable_erc20__ownable_erc20_contract")]
64+
#[test_case("proxy__proxy_contract")]
6465
#[test_case("upgradable_counter__counter_contract")]
6566
#[test_case("mintable__mintable_erc20_ownable")]
6667
#[test_case("multi_component__contract_with_4_components")]

crates/cairo-lang-starknet-classes/src/felt252_serde_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::test_utils::get_example_file_path;
1919
#[test_case("with_erc20")]
2020
#[test_case("with_ownable")]
2121
#[test_case("ownable_erc20")]
22+
#[test_case("proxy__proxy_contract")]
2223
#[test_case("upgradable_counter")]
2324
#[test_case("mintable")]
2425
#[test_case("multi_component__contract_with_4_components")]

crates/cairo-lang-starknet/cairo_level_tests/contracts.cairo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ mod mintable;
88
pub mod multi_component;
99
mod new_syntax_test_contract;
1010
mod ownable_erc20;
11+
mod proxy;
1112
mod storage_accesses;
1213
pub mod test_contract;
1314
mod token_bridge;
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#[starknet::contract]
2+
mod proxy_contract {
3+
use starknet::storage::StoragePointerReadAccess;
4+
use starknet::syscalls::library_call_syscall;
5+
use starknet::{ClassHash, SyscallResultTrait};
6+
7+
#[storage]
8+
struct Storage {
9+
forward: ClassHash,
10+
}
11+
12+
#[raw_input]
13+
#[raw_output]
14+
#[external(v0)]
15+
fn foo(ref self: ContractState, data: Span<felt252>) -> Span<felt252> {
16+
library_call_syscall(self.forward.read(), selector!("foo"), data).unwrap_syscall()
17+
}
18+
19+
#[raw_input]
20+
#[raw_output]
21+
#[external(v0)]
22+
fn bar(ref self: ContractState, data: Span<felt252>) -> Span<felt252> {
23+
library_call_syscall(self.forward.read(), selector!("bar"), data).unwrap_syscall()
24+
}
25+
}

crates/cairo-lang-starknet/src/compile_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ use crate::test_utils::{get_example_file_path, get_test_contract};
2222
#[test_case("with_ownable_mini::ownable_mini_contract")]
2323
#[test_case("with_erc20_mini::erc20_mini_contract")]
2424
#[test_case("ownable_erc20::ownable_erc20_contract")]
25+
#[test_case("proxy::proxy_contract")]
2526
#[test_case("upgradable_counter::counter_contract")]
2627
#[test_case("mintable::mintable_erc20_ownable")]
2728
#[test_case("multi_component::contract_with_4_components")]

crates/cairo-lang-starknet/src/plugin/consts.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pub const L1_HANDLER_ATTR: &str = "l1_handler";
3838
pub const CONSTRUCTOR_ATTR: &str = "constructor";
3939
pub const CONSTRUCTOR_NAME: &str = "constructor";
4040
pub(super) const RAW_OUTPUT_ATTR: &str = "raw_output";
41+
pub(super) const RAW_INPUT_ATTR: &str = "raw_input";
4142
pub const EMBEDDABLE_AS_ATTR: &str = "embeddable_as";
4243
pub const COMPONENT_INLINE_MACRO: &str = "component";
4344
pub const HAS_COMPONENT_TRAIT: &str = "HasComponent";

crates/cairo-lang-starknet/src/plugin/entry_point.rs

Lines changed: 62 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use salsa::Database;
1717
use super::consts::{
1818
CONSTRUCTOR_ATTR, CONSTRUCTOR_MODULE, CONSTRUCTOR_NAME, EXTERNAL_ATTR, EXTERNAL_MODULE,
1919
IMPLICIT_PRECEDENCE, L1_HANDLER_ATTR, L1_HANDLER_FIRST_PARAM_NAME, L1_HANDLER_MODULE,
20-
RAW_OUTPUT_ATTR, WRAPPER_PREFIX,
20+
RAW_INPUT_ATTR, RAW_OUTPUT_ATTR, WRAPPER_PREFIX,
2121
};
2222
use super::utils::{AstPathExtract, ParamEx, find_v0_attribute, maybe_strip_underscore};
2323

@@ -205,8 +205,6 @@ fn generate_entry_point_wrapper<'db>(
205205
let sig_params = sig.parameters(db);
206206
let mut params = sig_params.elements(db).enumerate();
207207
let mut diagnostics = vec![];
208-
let mut arg_names = Vec::new();
209-
let mut arg_definitions = Vec::new();
210208
let mut ref_appends = Vec::new();
211209

212210
let Some((0, first_param)) = params.next() else {
@@ -228,38 +226,65 @@ fn generate_entry_point_wrapper<'db>(
228226
// TODO(spapini): Check modifiers and type.
229227

230228
let raw_output = function.has_attr(db, RAW_OUTPUT_ATTR);
231-
for (param_idx, param) in params {
232-
let arg_name = format!("__arg_{}", param.name(db).text(db).long(db));
233-
let arg_type_ast =
234-
extract_matches!(param.type_clause(db), OptionTypeClause::TypeClause).ty(db);
235-
let type_name = arg_type_ast.as_syntax_node().get_text_without_trivia(db).long(db);
236-
237-
let is_ref = param.is_ref_param(db);
238-
if raw_output && is_ref {
229+
let (arg_processing, args_for_wrapped_function) = if let Some(raw_input_attr) =
230+
function.find_attr(db, RAW_INPUT_ATTR)
231+
{
232+
if params.len() != 1 {
239233
diagnostics.push(PluginDiagnostic::error(
240-
param.modifiers(db).stable_ptr(db),
241-
format!("`{RAW_OUTPUT_ATTR}` functions cannot have `ref` parameters."),
234+
sig.parameters(db).stable_ptr(db).untyped(),
235+
format!(
236+
"`{RAW_INPUT_ATTR}` functions must have a single parameter for calldata of \
237+
type `Span::<felt252>`."
238+
),
242239
));
243240
}
244-
let ref_modifier = if is_ref { "ref " } else { "" };
245-
arg_names.push(format!("{ref_modifier}{arg_name}"));
246-
let mut_modifier = if is_ref { "mut " } else { "" };
247-
let arg_definition = formatdoc!(
248-
"
241+
("".to_string(), RewriteNode::mapped_text("data", db, &raw_input_attr))
242+
} else {
243+
let (definitions, arg_names): (Vec<_>, Vec<_>) = params
244+
.map(|(param_idx, param)| {
245+
let arg_name = format!("__arg_{}", param.name(db).text(db).long(db));
246+
let arg_type_ast =
247+
extract_matches!(param.type_clause(db), OptionTypeClause::TypeClause).ty(db);
248+
let type_name = arg_type_ast.as_syntax_node().get_text_without_trivia(db).long(db);
249+
250+
let is_ref = param.is_ref_param(db);
251+
if is_ref {
252+
ref_appends.push(RewriteNode::Text(format!(
253+
"\n core::serde::Serde::<{type_name}>::serialize(@{arg_name}, \
254+
ref arr);"
255+
)));
256+
if raw_output {
257+
diagnostics.push(PluginDiagnostic::error(
258+
param.modifiers(db).stable_ptr(db),
259+
format!("`{RAW_OUTPUT_ATTR}` functions cannot have `ref` parameters."),
260+
));
261+
}
262+
}
263+
264+
let ref_modifier = if is_ref { "ref " } else { "" };
265+
let mut_modifier = if is_ref { "mut " } else { "" };
266+
(
267+
formatdoc!(
268+
"
249269
let {mut_modifier}{arg_name} = core::option::OptionTraitImpl::expect(
250270
core::serde::Serde::<{type_name}>::deserialize(ref data),
251271
'Failed to deserialize param #{param_idx}'
252272
);"
253-
);
254-
arg_definitions.push(arg_definition);
255-
256-
if is_ref {
257-
ref_appends.push(RewriteNode::Text(format!(
258-
"\n core::serde::Serde::<{type_name}>::serialize(@{arg_name}, ref arr);"
259-
)));
260-
}
261-
}
262-
let arg_names_str = arg_names.join(", ");
273+
),
274+
format!("{ref_modifier}{arg_name}"),
275+
)
276+
})
277+
.unzip();
278+
(
279+
definitions
280+
.into_iter()
281+
.chain(["assert(core::array::SpanTrait::is_empty(data), 'Input too long for \
282+
arguments');"
283+
.to_string()])
284+
.join("\n "),
285+
RewriteNode::Text(arg_names.join(", ")),
286+
)
287+
};
263288

264289
let ret_ty = sig.ret_ty(db);
265290
let (let_res, append_res, return_ty_is_felt252_span, ret_type_ptr) = match &ret_ty {
@@ -292,22 +317,23 @@ fn generate_entry_point_wrapper<'db>(
292317
}
293318

294319
let contract_state_arg = if is_snapshot { "@contract_state" } else { "ref contract_state" };
295-
let output_handling_string = if raw_output {
296-
format!("$wrapped_function_path$({contract_state_arg}, {arg_names_str})")
320+
let call_and_output_handling_string = if raw_output {
321+
format!("$wrapped_function_path$({contract_state_arg}, $args_for_wrapped_function$)")
297322
} else {
298323
formatdoc! {"
299-
{let_res}$wrapped_function_path$({contract_state_arg}, {arg_names_str});
324+
{let_res}$wrapped_function_path$({contract_state_arg}, $args_for_wrapped_function$);
300325
let mut arr = ArrayTrait::new();
301326
// References.$ref_appends$
302327
// Result.{append_res}
303328
core::array::ArrayTrait::span(@arr)"
304329
}
305330
};
306331

307-
let output_handling = RewriteNode::interpolate_patched(
308-
&output_handling_string,
332+
let call_and_output_handling = RewriteNode::interpolate_patched(
333+
&call_and_output_handling_string,
309334
&[
310335
("wrapped_function_path".to_string(), wrapped_function_path),
336+
("args_for_wrapped_function".to_string(), args_for_wrapped_function),
311337
("ref_appends".to_string(), RewriteNode::new_modified(ref_appends)),
312338
]
313339
.into(),
@@ -317,7 +343,6 @@ fn generate_entry_point_wrapper<'db>(
317343
IMPLICIT_PRECEDENCE.iter().join(", ")
318344
}));
319345

320-
let arg_definitions = RewriteNode::Text(arg_definitions.join("\n "));
321346
Ok(RewriteNode::interpolate_patched(
322347
&formatdoc! {"
323348
#[doc(hidden)]
@@ -328,20 +353,18 @@ fn generate_entry_point_wrapper<'db>(
328353
let Some(_) = core::gas::withdraw_gas() else {{
329354
core::panic_with_felt252('Out of gas');
330355
}};
331-
$arg_definitions$
332-
assert(core::array::SpanTrait::is_empty(data), 'Input too long for arguments');
356+
{arg_processing}
333357
let Some(_) = core::gas::withdraw_gas_all(core::gas::get_builtin_costs()) else {{
334358
core::panic_with_felt252('Out of gas');
335359
}};
336360
let mut contract_state = {unsafe_new_contract_state_prefix}unsafe_new_contract_state();
337-
$output_handling$
361+
$call_and_output_handling$
338362
}}
339363
"},
340364
&[
341365
("wrapper_function_name".to_string(), wrapper_function_name),
342366
("generic_params".to_string(), generic_params),
343-
("output_handling".to_string(), output_handling),
344-
("arg_definitions".to_string(), arg_definitions),
367+
("call_and_output_handling".to_string(), call_and_output_handling),
345368
("implicit_precedence".to_string(), implicit_precedence),
346369
]
347370
.into(),

crates/cairo-lang-starknet/src/plugin/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ impl MacroPlugin for StarknetPlugin {
8888
SmolStrId::from(db, L1_HANDLER_ATTR),
8989
SmolStrId::from(db, NESTED_ATTR),
9090
SmolStrId::from(db, RAW_OUTPUT_ATTR),
91+
SmolStrId::from(db, RAW_INPUT_ATTR),
9192
SmolStrId::from(db, STORAGE_ATTR),
9293
SmolStrId::from(db, SUBSTORAGE_ATTR),
9394
]

0 commit comments

Comments
 (0)