@@ -17,7 +17,7 @@ use salsa::Database;
1717use super :: consts:: {
1818 CONSTRUCTOR_ATTR , CONSTRUCTOR_MODULE , CONSTRUCTOR_NAME , EXTERNAL_ATTR , EXTERNAL_MODULE ,
1919 IMPLICIT_PRECEDENCE , L1_HANDLER_ATTR , L1_HANDLER_FIRST_PARAM_NAME , L1_HANDLER_MODULE ,
20- RAW_OUTPUT_ATTR , WRAPPER_PREFIX ,
20+ RAW_INPUT_ATTR , RAW_OUTPUT_ATTR , WRAPPER_PREFIX ,
2121} ;
2222use super :: utils:: { AstPathExtract , ParamEx , find_v0_attribute, maybe_strip_underscore} ;
2323
@@ -205,8 +205,6 @@ fn generate_entry_point_wrapper<'db>(
205205 let sig_params = sig. parameters ( db) ;
206206 let mut params = sig_params. elements ( db) . enumerate ( ) ;
207207 let mut diagnostics = vec ! [ ] ;
208- let mut arg_names = Vec :: new ( ) ;
209- let mut arg_definitions = Vec :: new ( ) ;
210208 let mut ref_appends = Vec :: new ( ) ;
211209
212210 let Some ( ( 0 , first_param) ) = params. next ( ) else {
@@ -228,38 +226,65 @@ fn generate_entry_point_wrapper<'db>(
228226 // TODO(spapini): Check modifiers and type.
229227
230228 let raw_output = function. has_attr ( db, RAW_OUTPUT_ATTR ) ;
231- for ( param_idx, param) in params {
232- let arg_name = format ! ( "__arg_{}" , param. name( db) . text( db) . long( db) ) ;
233- let arg_type_ast =
234- extract_matches ! ( param. type_clause( db) , OptionTypeClause :: TypeClause ) . ty ( db) ;
235- let type_name = arg_type_ast. as_syntax_node ( ) . get_text_without_trivia ( db) . long ( db) ;
236-
237- let is_ref = param. is_ref_param ( db) ;
238- if raw_output && is_ref {
229+ let ( arg_processing, args_for_wrapped_function) = if let Some ( raw_input_attr) =
230+ function. find_attr ( db, RAW_INPUT_ATTR )
231+ {
232+ if params. len ( ) != 1 {
239233 diagnostics. push ( PluginDiagnostic :: error (
240- param. modifiers ( db) . stable_ptr ( db) ,
241- format ! ( "`{RAW_OUTPUT_ATTR}` functions cannot have `ref` parameters." ) ,
234+ sig. parameters ( db) . stable_ptr ( db) . untyped ( ) ,
235+ format ! (
236+ "`{RAW_INPUT_ATTR}` functions must have a single parameter for calldata of \
237+ type `Span::<felt252>`."
238+ ) ,
242239 ) ) ;
243240 }
244- let ref_modifier = if is_ref { "ref " } else { "" } ;
245- arg_names. push ( format ! ( "{ref_modifier}{arg_name}" ) ) ;
246- let mut_modifier = if is_ref { "mut " } else { "" } ;
247- let arg_definition = formatdoc ! (
248- "
241+ ( "" . to_string ( ) , RewriteNode :: mapped_text ( "data" , db, & raw_input_attr) )
242+ } else {
243+ let ( definitions, arg_names) : ( Vec < _ > , Vec < _ > ) = params
244+ . map ( |( param_idx, param) | {
245+ let arg_name = format ! ( "__arg_{}" , param. name( db) . text( db) . long( db) ) ;
246+ let arg_type_ast =
247+ extract_matches ! ( param. type_clause( db) , OptionTypeClause :: TypeClause ) . ty ( db) ;
248+ let type_name = arg_type_ast. as_syntax_node ( ) . get_text_without_trivia ( db) . long ( db) ;
249+
250+ let is_ref = param. is_ref_param ( db) ;
251+ if is_ref {
252+ ref_appends. push ( RewriteNode :: Text ( format ! (
253+ "\n core::serde::Serde::<{type_name}>::serialize(@{arg_name}, \
254+ ref arr);"
255+ ) ) ) ;
256+ if raw_output {
257+ diagnostics. push ( PluginDiagnostic :: error (
258+ param. modifiers ( db) . stable_ptr ( db) ,
259+ format ! ( "`{RAW_OUTPUT_ATTR}` functions cannot have `ref` parameters." ) ,
260+ ) ) ;
261+ }
262+ }
263+
264+ let ref_modifier = if is_ref { "ref " } else { "" } ;
265+ let mut_modifier = if is_ref { "mut " } else { "" } ;
266+ (
267+ formatdoc ! (
268+ "
249269 let {mut_modifier}{arg_name} = core::option::OptionTraitImpl::expect(
250270 core::serde::Serde::<{type_name}>::deserialize(ref data),
251271 'Failed to deserialize param #{param_idx}'
252272 );"
253- ) ;
254- arg_definitions. push ( arg_definition) ;
255-
256- if is_ref {
257- ref_appends. push ( RewriteNode :: Text ( format ! (
258- "\n core::serde::Serde::<{type_name}>::serialize(@{arg_name}, ref arr);"
259- ) ) ) ;
260- }
261- }
262- let arg_names_str = arg_names. join ( ", " ) ;
273+ ) ,
274+ format ! ( "{ref_modifier}{arg_name}" ) ,
275+ )
276+ } )
277+ . unzip ( ) ;
278+ (
279+ definitions
280+ . into_iter ( )
281+ . chain ( [ "assert(core::array::SpanTrait::is_empty(data), 'Input too long for \
282+ arguments');"
283+ . to_string ( ) ] )
284+ . join ( "\n " ) ,
285+ RewriteNode :: Text ( arg_names. join ( ", " ) ) ,
286+ )
287+ } ;
263288
264289 let ret_ty = sig. ret_ty ( db) ;
265290 let ( let_res, append_res, return_ty_is_felt252_span, ret_type_ptr) = match & ret_ty {
@@ -292,22 +317,23 @@ fn generate_entry_point_wrapper<'db>(
292317 }
293318
294319 let contract_state_arg = if is_snapshot { "@contract_state" } else { "ref contract_state" } ;
295- let output_handling_string = if raw_output {
296- format ! ( "$wrapped_function_path$({contract_state_arg}, {arg_names_str} )" )
320+ let call_and_output_handling_string = if raw_output {
321+ format ! ( "$wrapped_function_path$({contract_state_arg}, $args_for_wrapped_function$ )" )
297322 } else {
298323 formatdoc ! { "
299- {let_res}$wrapped_function_path$({contract_state_arg}, {arg_names_str} );
324+ {let_res}$wrapped_function_path$({contract_state_arg}, $args_for_wrapped_function$ );
300325 let mut arr = ArrayTrait::new();
301326 // References.$ref_appends$
302327 // Result.{append_res}
303328 core::array::ArrayTrait::span(@arr)"
304329 }
305330 } ;
306331
307- let output_handling = RewriteNode :: interpolate_patched (
308- & output_handling_string ,
332+ let call_and_output_handling = RewriteNode :: interpolate_patched (
333+ & call_and_output_handling_string ,
309334 & [
310335 ( "wrapped_function_path" . to_string ( ) , wrapped_function_path) ,
336+ ( "args_for_wrapped_function" . to_string ( ) , args_for_wrapped_function) ,
311337 ( "ref_appends" . to_string ( ) , RewriteNode :: new_modified ( ref_appends) ) ,
312338 ]
313339 . into ( ) ,
@@ -317,7 +343,6 @@ fn generate_entry_point_wrapper<'db>(
317343 IMPLICIT_PRECEDENCE . iter( ) . join( ", " )
318344 } ) ) ;
319345
320- let arg_definitions = RewriteNode :: Text ( arg_definitions. join ( "\n " ) ) ;
321346 Ok ( RewriteNode :: interpolate_patched (
322347 & formatdoc ! { "
323348 #[doc(hidden)]
@@ -328,20 +353,18 @@ fn generate_entry_point_wrapper<'db>(
328353 let Some(_) = core::gas::withdraw_gas() else {{
329354 core::panic_with_felt252('Out of gas');
330355 }};
331- $arg_definitions$
332- assert(core::array::SpanTrait::is_empty(data), 'Input too long for arguments');
356+ {arg_processing}
333357 let Some(_) = core::gas::withdraw_gas_all(core::gas::get_builtin_costs()) else {{
334358 core::panic_with_felt252('Out of gas');
335359 }};
336360 let mut contract_state = {unsafe_new_contract_state_prefix}unsafe_new_contract_state();
337- $output_handling $
361+ $call_and_output_handling $
338362 }}
339363 " } ,
340364 & [
341365 ( "wrapper_function_name" . to_string ( ) , wrapper_function_name) ,
342366 ( "generic_params" . to_string ( ) , generic_params) ,
343- ( "output_handling" . to_string ( ) , output_handling) ,
344- ( "arg_definitions" . to_string ( ) , arg_definitions) ,
367+ ( "call_and_output_handling" . to_string ( ) , call_and_output_handling) ,
345368 ( "implicit_precedence" . to_string ( ) , implicit_precedence) ,
346369 ]
347370 . into ( ) ,
0 commit comments