@@ -7,7 +7,7 @@ use pyo3::types::{PyBytes, PyTuple};
77use std:: collections:: HashMap ;
88use std:: collections:: HashSet ;
99use std:: marker:: PhantomData ;
10- use std:: sync:: Arc ;
10+ use std:: sync:: { Arc , OnceLock } ;
1111use std:: time:: Duration ;
1212use temporal_sdk_core:: api:: errors:: { PollActivityError , PollWfError } ;
1313use temporal_sdk_core:: replay:: { HistoryForReplay , ReplayWorkerInput } ;
@@ -202,16 +202,58 @@ impl CustomSlotSupplier {
202202 }
203203}
204204
205+ #[ pyclass]
206+ struct CreatedTaskForSlotCallback {
207+ stored_task : Arc < OnceLock < PyObject > > ,
208+ }
209+
210+ #[ pymethods]
211+ impl CreatedTaskForSlotCallback {
212+ fn __call__ ( & self , task : PyObject ) -> PyResult < ( ) > {
213+ self . stored_task . set ( task) . expect ( "must only be set once" ) ;
214+ Ok ( ( ) )
215+ }
216+ }
217+
218+ struct TaskCanceller {
219+ stored_task : Arc < OnceLock < PyObject > > ,
220+ }
221+
222+ impl TaskCanceller {
223+ fn new ( stored_task : Arc < OnceLock < PyObject > > ) -> Self {
224+ TaskCanceller { stored_task }
225+ }
226+ }
227+
228+ impl Drop for TaskCanceller {
229+ fn drop ( & mut self ) {
230+ if let Some ( task) = self . stored_task . get ( ) {
231+ Python :: with_gil ( |py| {
232+ task. call_method0 ( py, "cancel" )
233+ . expect ( "Failed to cancel task" ) ;
234+ } ) ;
235+ }
236+ }
237+ }
238+
205239#[ async_trait:: async_trait]
206240impl < SK : SlotKind + Send + Sync > SlotSupplierTrait for CustomSlotSupplierOfType < SK > {
207241 type SlotKind = SK ;
208242
209243 async fn reserve_slot ( & self , ctx : & dyn SlotReservationContext ) -> SlotSupplierPermit {
244+ dbg ! ( "Invoking reserve first time" ) ;
210245 loop {
246+ let stored_task = Arc :: new ( OnceLock :: new ( ) ) ;
247+ let _task_canceller = TaskCanceller :: new ( stored_task. clone ( ) ) ;
211248 let pypermit = match Python :: with_gil ( |py| {
212249 let py_obj = self . inner . as_ref ( py) ;
213- let called = py_obj
214- . call_method1 ( "reserve_slot" , ( SlotReserveCtx :: from_ctx ( SK :: kind ( ) , ctx) , ) ) ?;
250+ let called = py_obj. call_method1 (
251+ "reserve_slot" ,
252+ (
253+ SlotReserveCtx :: from_ctx ( SK :: kind ( ) , ctx) ,
254+ CreatedTaskForSlotCallback { stored_task } ,
255+ ) ,
256+ ) ?;
215257 runtime:: THREAD_TASK_LOCAL
216258 . with ( |tl| pyo3_asyncio:: into_future_with_locals ( tl. get ( ) . unwrap ( ) , called) )
217259 } ) {
0 commit comments