@@ -7,7 +7,7 @@ use pyo3::types::{PyBytes, PyTuple};
77use std:: collections:: HashMap ;
88use std:: collections:: HashSet ;
99use std:: marker:: PhantomData ;
10- use std:: sync:: Arc ;
10+ use std:: sync:: { Arc , OnceLock } ;
1111use std:: time:: Duration ;
1212use temporal_sdk_core:: api:: errors:: { PollActivityError , PollWfError } ;
1313use temporal_sdk_core:: replay:: { HistoryForReplay , ReplayWorkerInput } ;
@@ -64,6 +64,18 @@ pub struct TunerHolder {
6464 local_activity_slot_supplier : SlotSupplier ,
6565}
6666
67+ // pub fn set_task_locals_on_tuner<'a>(py: Python<'a>, tuner: &TunerHolder) -> PyResult<()> {
68+ // // TODO: All suppliers
69+ // if let SlotSupplier::Custom(ref cs) = tuner.workflow_slot_supplier {
70+ // Python::with_gil(|py| {
71+ // let py_obj = cs.inner.as_ref(py);
72+ // py_obj.call_method0("set_task_locals")?;
73+ // Ok(())
74+ // })?;
75+ // };
76+ // Ok(())
77+ // }
78+
6779#[ derive( FromPyObject ) ]
6880pub enum SlotSupplier {
6981 FixedSize ( FixedSizeSlotSupplier ) ,
@@ -190,17 +202,60 @@ impl CustomSlotSupplier {
190202 }
191203}
192204
205+ // Shouldn't really need this callback nonsense, it should be possible to do this from the pyo3
206+ // asyncio library, but we'd have to vendor the whole thing to make the right improvements. When
207+ // pyo3 is upgraded and we are using
208+
209+ #[ pyclass]
210+ struct CreatedTaskForSlotCallback {
211+ stored_task : Arc < OnceLock < PyObject > > ,
212+ }
213+
214+ #[ pymethods]
215+ impl CreatedTaskForSlotCallback {
216+ fn __call__ ( & self , task : PyObject ) -> PyResult < ( ) > {
217+ self . stored_task . set ( task) . expect ( "must only be set once" ) ;
218+ Ok ( ( ) )
219+ }
220+ }
221+
222+ struct TaskCanceller {
223+ stored_task : Arc < OnceLock < PyObject > > ,
224+ }
225+
226+ impl TaskCanceller {
227+ fn new ( stored_task : Arc < OnceLock < PyObject > > ) -> Self {
228+ TaskCanceller { stored_task }
229+ }
230+ }
231+
232+ impl Drop for TaskCanceller {
233+ fn drop ( & mut self ) {
234+ if let Some ( task) = self . stored_task . get ( ) {
235+ Python :: with_gil ( |py| {
236+ task. call_method0 ( py, "cancel" )
237+ . expect ( "Failed to cancel task" ) ;
238+ } ) ;
239+ }
240+ }
241+ }
242+
193243#[ async_trait:: async_trait]
194244impl < SK : SlotKind + Send + Sync > SlotSupplierTrait for CustomSlotSupplierOfType < SK > {
195245 type SlotKind = SK ;
196246
197247 async fn reserve_slot ( & self , ctx : & dyn SlotReservationContext ) -> SlotSupplierPermit {
198248 loop {
249+ let stored_task = Arc :: new ( OnceLock :: new ( ) ) ;
250+ let _task_canceller = TaskCanceller :: new ( stored_task. clone ( ) ) ;
199251 let pypermit = match Python :: with_gil ( |py| {
200252 let py_obj = self . inner . as_ref ( py) ;
201253 let called = py_obj. call_method1 (
202254 "reserve_slot" ,
203- ( SlotReserveCtx :: from_ctx ( Self :: SlotKind :: kind ( ) , ctx) , ) ,
255+ (
256+ SlotReserveCtx :: from_ctx ( SK :: kind ( ) , ctx) ,
257+ CreatedTaskForSlotCallback { stored_task } ,
258+ ) ,
204259 ) ?;
205260 runtime:: THREAD_TASK_LOCAL
206261 . with ( |tl| pyo3_asyncio:: into_future_with_locals ( tl. get ( ) . unwrap ( ) , called) )
@@ -232,7 +287,7 @@ impl<SK: SlotKind + Send + Sync> SlotSupplierTrait for CustomSlotSupplierOfType<
232287 let py_obj = self . inner . as_ref ( py) ;
233288 let pa = py_obj. call_method1 (
234289 "try_reserve_slot" ,
235- ( SlotReserveCtx :: from_ctx ( Self :: SlotKind :: kind ( ) , ctx) , ) ,
290+ ( SlotReserveCtx :: from_ctx ( SK :: kind ( ) , ctx) , ) ,
236291 ) ?;
237292
238293 if pa. is_none ( ) {
@@ -362,6 +417,8 @@ pub fn new_replay_worker<'a>(
362417impl WorkerRef {
363418 fn validate < ' p > ( & self , py : Python < ' p > ) -> PyResult < & ' p PyAny > {
364419 let worker = self . worker . as_ref ( ) . unwrap ( ) . clone ( ) ;
420+ // Set custom slot supplier task locals so they can run futures
421+ // match worker.get_config().tuner {}
365422 self . runtime . future_into_py ( py, async move {
366423 worker
367424 . validate ( )
0 commit comments