@@ -127,6 +127,7 @@ def __init__(
127127 cast_types : bool = True ,
128128 max_async_tasks : int = 30 ,
129129 propagate_exceptions : bool = True ,
130+ await_inplace : bool = False ,
130131 ) -> None :
131132 super ().__init__ ()
132133 self .result_backend = InmemoryResultBackend (
@@ -140,6 +141,7 @@ def __init__(
140141 max_async_tasks = max_async_tasks ,
141142 propagate_exceptions = propagate_exceptions ,
142143 )
144+ self .await_inplace = await_inplace
143145 self ._running_tasks : "Set[asyncio.Task[Any]]" = set ()
144146
145147 async def kick (self , message : BrokerMessage ) -> None :
@@ -156,7 +158,12 @@ async def kick(self, message: BrokerMessage) -> None:
156158 if target_task is None :
157159 raise TaskiqError ("Unknown task." )
158160
159- task = asyncio .create_task (self .receiver .callback (message = message .message ))
161+ receiver_cb = self .receiver .callback (message = message .message )
162+ if self .await_inplace :
163+ await receiver_cb
164+ return
165+
166+ task = asyncio .create_task (receiver_cb )
160167 self ._running_tasks .add (task )
161168 task .add_done_callback (self ._running_tasks .discard )
162169
@@ -171,6 +178,17 @@ def listen(self) -> AsyncGenerator[bytes, None]:
171178 """
172179 raise RuntimeError ("Inmemory brokers cannot listen." )
173180
181+ async def wait_all (self ) -> None :
182+ """
183+ Wait for all currently running tasks to complete.
184+
185+ Useful when used in testing and you need to await all sent tasks
186+ before asserting results.
187+ """
188+ to_await = list (self ._running_tasks )
189+ for task in to_await :
190+ await task
191+
174192 async def startup (self ) -> None :
175193 """Runs startup events for client and worker side."""
176194 for event in (TaskiqEvents .CLIENT_STARTUP , TaskiqEvents .WORKER_STARTUP ):
0 commit comments