11use crate :: syntax:: { ast:: App , Context } ;
22use crate :: { analyze:: Analysis , codegen:: bindings:: interrupt_mod, codegen:: util} ;
3+
34use proc_macro2:: TokenStream as TokenStream2 ;
45use quote:: quote;
56
@@ -112,37 +113,7 @@ pub fn codegen(ctxt: Context, app: &App, analysis: &Analysis) -> TokenStream2 {
112113 let internal_context_name = util:: internal_task_ident ( name, "Context" ) ;
113114 let exec_name = util:: internal_task_ident ( name, "EXEC" ) ;
114115
115- items. push ( quote ! (
116- #( #cfgs) *
117- /// Execution context
118- #[ allow( non_snake_case) ]
119- #[ allow( non_camel_case_types) ]
120- pub struct #internal_context_name<' a> {
121- #[ doc( hidden) ]
122- __rtic_internal_p: :: core:: marker:: PhantomData <& ' a ( ) >,
123- #( #fields, ) *
124- }
125-
126- #( #cfgs) *
127- impl <' a> #internal_context_name<' a> {
128- #[ inline( always) ]
129- #[ allow( missing_docs) ]
130- pub unsafe fn new( #core) -> Self {
131- #internal_context_name {
132- __rtic_internal_p: :: core:: marker:: PhantomData ,
133- #( #values, ) *
134- }
135- }
136- }
137- ) ) ;
138-
139- module_items. push ( quote ! (
140- #( #cfgs) *
141- #[ doc( inline) ]
142- pub use super :: #internal_context_name as Context ;
143- ) ) ;
144-
145- if let Context :: SoftwareTask ( ..) = ctxt {
116+ if let Context :: SoftwareTask ( t) = ctxt {
146117 let spawnee = & app. software_tasks [ name] ;
147118 let priority = spawnee. args . priority ;
148119 let cfgs = & spawnee. cfgs ;
@@ -158,18 +129,21 @@ pub fn codegen(ctxt: Context, app: &App, analysis: &Analysis) -> TokenStream2 {
158129 } ;
159130
160131 let internal_spawn_ident = util:: internal_task_ident ( name, "spawn" ) ;
132+ let internal_spawn_helper_ident = util:: internal_task_ident ( name, "spawn_helper" ) ;
161133 let internal_waker_ident = util:: internal_task_ident ( name, "waker" ) ;
162134 let from_ptr_n_args = util:: from_ptr_n_args_ident ( spawnee. inputs . len ( ) ) ;
163- let ( input_args, input_tupled, input_untupled, input_ty) =
135+ let ( generic_input_args , input_args, input_tupled, input_untupled, input_ty) =
164136 util:: regroup_inputs ( & spawnee. inputs ) ;
165137
166138 // Spawn caller
167139 items. push ( quote ! (
168140 #( #cfgs) *
169- /// Spawns the task directly
141+ /// Spawns the task without checking if the spawner and spawnee are the same priority
142+ ///
143+ /// SAFETY: The caller needs to check that the spawner and spawnee are the same priority
170144 #[ allow( non_snake_case) ]
171145 #[ doc( hidden) ]
172- pub fn #internal_spawn_ident ( #( #input_args, ) * ) -> :: core:: result:: Result <( ) , #input_ty> {
146+ pub unsafe fn #internal_spawn_helper_ident ( #( #input_args, ) * ) -> :: core:: result:: Result <( ) , #input_ty> {
173147 // SAFETY: If `try_allocate` succeeds one must call `spawn`, which we do.
174148 unsafe {
175149 let exec = rtic:: export:: executor:: AsyncTaskExecutor :: #from_ptr_n_args( #name, & #exec_name) ;
@@ -183,6 +157,14 @@ pub fn codegen(ctxt: Context, app: &App, analysis: &Analysis) -> TokenStream2 {
183157 }
184158 }
185159 }
160+
161+ /// Spawns the task directly
162+ #[ allow( non_snake_case) ]
163+ #[ doc( hidden) ]
164+ pub fn #internal_spawn_ident( #( #generic_input_args, ) * ) -> :: core:: result:: Result <( ) , #input_ty> {
165+ // SAFETY: The generic args require Send + Sync
166+ unsafe { #internal_spawn_helper_ident( #( #input_untupled. to( ) ) , * ) }
167+ }
186168 ) ) ;
187169
188170 // Waker
@@ -204,11 +186,63 @@ pub fn codegen(ctxt: Context, app: &App, analysis: &Analysis) -> TokenStream2 {
204186 }
205187 ) ) ;
206188
207- module_items. push ( quote ! (
189+ module_items. push ( quote ! {
208190 #( #cfgs) *
209191 #[ doc( inline) ]
210192 pub use super :: #internal_spawn_ident as spawn;
211- ) ) ;
193+ } ) ;
194+
195+ let tasks_on_same_executor: Vec < _ > = app
196+ . software_tasks
197+ . iter ( )
198+ . filter ( |( _, t) | t. args . priority == priority)
199+ . collect ( ) ;
200+
201+ if !tasks_on_same_executor. is_empty ( ) {
202+ let local_spawner = util:: internal_task_ident ( t, "LocalSpawner" ) ;
203+ fields. push ( quote ! {
204+ /// Used to spawn tasks on the same executor
205+ ///
206+ /// This is useful for tasks that take args which are !Send/!Sync.
207+ pub local_spawner: #local_spawner
208+ } ) ;
209+ let tasks = tasks_on_same_executor
210+ . iter ( )
211+ . map ( |( ident, task) | {
212+ // Copied mostly from software_tasks.rs
213+ let internal_spawn_ident = util:: internal_task_ident ( ident, "spawn_helper" ) ;
214+ let attrs = & task. attrs ;
215+ let cfgs = & task. cfgs ;
216+ let generics = if task. is_bottom {
217+ quote ! ( )
218+ } else {
219+ quote ! ( <' a>)
220+ } ;
221+
222+ let ( _generic_input_args, input_args, _input_tupled, input_untupled, input_ty) = util:: regroup_inputs ( & task. inputs ) ;
223+ quote ! {
224+ #( #attrs) *
225+ #( #cfgs) *
226+ #[ allow( non_snake_case) ]
227+ pub ( super ) fn #ident #generics( & self #( , #input_args) * ) -> :: core:: result:: Result <( ) , #input_ty> {
228+ // SAFETY: This is safe to call since this can only be called
229+ // from the same executor
230+ unsafe { #internal_spawn_ident( #( #input_untupled, ) * ) }
231+ }
232+ }
233+ } )
234+ . collect :: < Vec < _ > > ( ) ;
235+ values. push ( quote ! ( local_spawner: #local_spawner { _p: core:: marker:: PhantomData } ) ) ;
236+ items. push ( quote ! {
237+ struct #local_spawner {
238+ _p: core:: marker:: PhantomData <* mut ( ) >,
239+ }
240+
241+ impl #local_spawner {
242+ #( #tasks) *
243+ }
244+ } ) ;
245+ }
212246
213247 module_items. push ( quote ! (
214248 #( #cfgs) *
@@ -217,6 +251,36 @@ pub fn codegen(ctxt: Context, app: &App, analysis: &Analysis) -> TokenStream2 {
217251 ) ) ;
218252 }
219253
254+ items. push ( quote ! (
255+ #( #cfgs) *
256+ /// Execution context
257+ #[ allow( non_snake_case) ]
258+ #[ allow( non_camel_case_types) ]
259+ pub struct #internal_context_name<' a> {
260+ #[ doc( hidden) ]
261+ __rtic_internal_p: :: core:: marker:: PhantomData <& ' a ( ) >,
262+ #( #fields, ) *
263+ }
264+
265+ #( #cfgs) *
266+ impl <' a> #internal_context_name<' a> {
267+ #[ inline( always) ]
268+ #[ allow( missing_docs) ]
269+ pub unsafe fn new( #core) -> Self {
270+ #internal_context_name {
271+ __rtic_internal_p: :: core:: marker:: PhantomData ,
272+ #( #values, ) *
273+ }
274+ }
275+ }
276+ ) ) ;
277+
278+ module_items. push ( quote ! (
279+ #( #cfgs) *
280+ #[ doc( inline) ]
281+ pub use super :: #internal_context_name as Context ;
282+ ) ) ;
283+
220284 if items. is_empty ( ) {
221285 quote ! ( )
222286 } else {
0 commit comments