11use anyhow:: Context ;
22use prost:: Message ;
3- use pyo3:: exceptions:: { PyException , PyRuntimeError , PyValueError } ;
3+ use pyo3:: exceptions:: { PyException , PyRuntimeError , PyTypeError , PyValueError } ;
44use pyo3:: prelude:: * ;
55use pyo3:: types:: { PyBytes , PyTuple } ;
6+ use pyo3_asyncio:: generic:: ContextExt ;
67use std:: collections:: HashMap ;
78use std:: collections:: HashSet ;
9+ use std:: marker:: PhantomData ;
810use std:: sync:: Arc ;
911use std:: time:: Duration ;
1012use temporal_sdk_core:: api:: errors:: { PollActivityError , PollWfError } ;
1113use temporal_sdk_core:: replay:: { HistoryForReplay , ReplayWorkerInput } ;
1214use temporal_sdk_core_api:: errors:: WorkflowErrorType ;
13- use temporal_sdk_core_api:: worker:: SlotKind ;
15+ use temporal_sdk_core_api:: worker:: {
16+ SlotKind , SlotMarkUsedContext , SlotReleaseContext , SlotReservationContext ,
17+ SlotSupplier as SlotSupplierTrait , SlotSupplierPermit ,
18+ } ;
1419use temporal_sdk_core_api:: Worker ;
1520use temporal_sdk_core_protos:: coresdk:: workflow_completion:: WorkflowActivationCompletion ;
1621use temporal_sdk_core_protos:: coresdk:: { ActivityHeartbeat , ActivityTaskCompletion } ;
@@ -20,6 +25,7 @@ use tokio_stream::wrappers::ReceiverStream;
2025
2126use crate :: client;
2227use crate :: runtime;
28+ use crate :: runtime:: { TokioRuntime , THREAD_TASK_LOCAL } ;
2329
2430pyo3:: create_exception!( temporal_sdk_bridge, PollShutdownError , PyException ) ;
2531
@@ -63,6 +69,7 @@ pub struct TunerHolder {
6369pub enum SlotSupplier {
6470 FixedSize ( FixedSizeSlotSupplier ) ,
6571 ResourceBased ( ResourceBasedSlotSupplier ) ,
72+ Custom ( CustomSlotSupplier ) ,
6673}
6774
6875#[ derive( FromPyObject ) ]
@@ -79,6 +86,125 @@ pub struct ResourceBasedSlotSupplier {
7986 tuner_config : ResourceBasedTunerConfig ,
8087}
8188
89+ #[ pyclass]
90+ pub struct SlotReserveCtx {
91+ slot_type : String , // TODO: Real type
92+ task_queue : String ,
93+ worker_identity : String ,
94+ worker_build_id : String ,
95+ is_sticky : bool ,
96+ }
97+
98+ impl SlotReserveCtx {
99+ fn from_ctx ( slot_type : String , ctx : & dyn SlotReservationContext ) -> Self {
100+ SlotReserveCtx {
101+ slot_type,
102+ task_queue : ctx. task_queue ( ) . to_string ( ) ,
103+ worker_identity : ctx. worker_identity ( ) . to_string ( ) ,
104+ worker_build_id : ctx. worker_build_id ( ) . to_string ( ) ,
105+ is_sticky : ctx. is_sticky ( ) ,
106+ }
107+ }
108+ }
109+
110+ #[ pyclass]
111+ pub struct SlotMarkUsedCtx { }
112+
113+ #[ pyclass]
114+ pub struct SlotReleaseCtx { }
115+
116+ #[ pyclass]
117+ #[ derive( Clone ) ]
118+ pub struct CustomSlotSupplier {
119+ inner : PyObject ,
120+ }
121+
122+ struct CustomSlotSupplierOfType < SK : SlotKind > {
123+ inner : PyObject ,
124+ _phantom : PhantomData < SK > ,
125+ }
126+
127+ #[ pymethods]
128+ impl CustomSlotSupplier {
129+ #[ new]
130+ fn new ( inner : PyObject ) -> Self {
131+ CustomSlotSupplier { inner }
132+ }
133+ }
134+
135+ impl < SK : SlotKind > CustomSlotSupplierOfType < SK > {
136+ fn call_method < P : IntoPy < PyObject > , F : FnOnce ( Python < ' _ > , & PyAny ) -> FR , FR > (
137+ & self ,
138+ method_name : & str ,
139+ arg : P ,
140+ post_closure : F ,
141+ ) -> FR {
142+ Python :: with_gil ( |py| {
143+ let py_obj = self . inner . as_ref ( py) ;
144+ let method = py_obj
145+ . getattr ( method_name)
146+ . map_err ( |_| {
147+ PyTypeError :: new_err ( format ! (
148+ "CustomSlotSupplier must implement '{}' method" ,
149+ method_name
150+ ) )
151+ } )
152+ . expect ( "TODO" ) ;
153+
154+ post_closure ( py, method. call ( ( arg. into_py ( py) , ) , None ) . expect ( "TODO" ) )
155+ } )
156+ }
157+ }
158+
159+ #[ async_trait:: async_trait]
160+ impl < SK : SlotKind + Send + Sync > SlotSupplierTrait for CustomSlotSupplierOfType < SK > {
161+ type SlotKind = SK ;
162+
163+ async fn reserve_slot ( & self , ctx : & dyn SlotReservationContext ) -> SlotSupplierPermit {
164+ dbg ! ( "Trying to reserve slot" ) ;
165+ let pypermit = Python :: with_gil ( |py| {
166+ let py_obj = self . inner . as_ref ( py) ;
167+ let called = py_obj. call_method1 (
168+ "reserve_slot" ,
169+ ( SlotReserveCtx :: from_ctx (
170+ Self :: SlotKind :: kind ( ) . to_string ( ) ,
171+ ctx,
172+ ) , ) ,
173+ ) ?;
174+ THREAD_TASK_LOCAL
175+ . with ( |tl| pyo3_asyncio:: into_future_with_locals ( tl. get ( ) . unwrap ( ) , called) )
176+ } )
177+ . expect ( "TODO" )
178+ . await ;
179+ SlotSupplierPermit :: with_user_data ( pypermit)
180+ }
181+
182+ fn try_reserve_slot ( & self , ctx : & dyn SlotReservationContext ) -> Option < SlotSupplierPermit > {
183+ self . call_method (
184+ "try_reserve_slot" ,
185+ SlotReserveCtx :: from_ctx ( Self :: SlotKind :: kind ( ) . to_string ( ) , ctx) ,
186+ |py, pa| {
187+ if pa. is_none ( ) {
188+ return None ;
189+ }
190+ Some ( SlotSupplierPermit :: with_user_data ( pa. into_py ( py) ) )
191+ } ,
192+ )
193+ }
194+
195+ fn mark_slot_used ( & self , _ctx : & dyn SlotMarkUsedContext < SlotKind = Self :: SlotKind > ) {
196+ self . call_method ( "mark_slot_used" , SlotMarkUsedCtx { } , |_, _| ( ) )
197+ }
198+
199+ fn release_slot ( & self , _ctx : & dyn SlotReleaseContext < SlotKind = Self :: SlotKind > ) {
200+ self . call_method ( "release_slot" , SlotReleaseCtx { } , |_, _| ( ) )
201+ }
202+
203+ fn available_slots ( & self ) -> Option < usize > {
204+ None
205+ }
206+ }
207+
82208#[ derive( FromPyObject , Clone , Copy , PartialEq ) ]
83209pub struct ResourceBasedTunerConfig {
84210 target_memory_usage : f64 ,
@@ -369,7 +495,9 @@ impl TryFrom<TunerHolder> for temporal_sdk_core::TunerHolder {
369495 }
370496}
371497
372- impl < SK : SlotKind > TryFrom < SlotSupplier > for temporal_sdk_core:: SlotSupplierOptions < SK > {
498+ impl < SK : SlotKind + Send + Sync + ' static > TryFrom < SlotSupplier >
499+ for temporal_sdk_core:: SlotSupplierOptions < SK >
500+ {
373501 type Error = PyErr ;
374502
375503 fn try_from ( supplier : SlotSupplier ) -> PyResult < temporal_sdk_core:: SlotSupplierOptions < SK > > {
@@ -386,6 +514,12 @@ impl<SK: SlotKind> TryFrom<SlotSupplier> for temporal_sdk_core::SlotSupplierOpti
386514 ) ,
387515 )
388516 }
517+ SlotSupplier :: Custom ( cs) => temporal_sdk_core:: SlotSupplierOptions :: Custom ( Arc :: new (
518+ CustomSlotSupplierOfType :: < SK > {
519+ inner : cs. inner ,
520+ _phantom : PhantomData ,
521+ } ,
522+ ) ) ,
389523 } )
390524 }
391525}
0 commit comments