@@ -7,15 +7,17 @@ use futures_util::FutureExt;
77use std:: {
88 cell:: Cell ,
99 sync:: {
10- Arc ,
10+ Arc , Mutex ,
1111 atomic:: { AtomicBool , Ordering :: Relaxed } ,
1212 } ,
1313 time:: Duration ,
1414} ;
1515use temporal_client:: WorkflowOptions ;
16- use temporal_sdk:: { ActivityOptions , WfContext , interceptors:: WorkerInterceptor } ;
16+ use temporal_sdk:: {
17+ ActivityOptions , LocalActivityOptions , WfContext , interceptors:: WorkerInterceptor ,
18+ } ;
1719use temporal_sdk_core:: {
18- CoreRuntime , ResourceBasedTuner , ResourceSlotOptions , init_worker,
20+ CoreRuntime , ResourceBasedTuner , ResourceSlotOptions , TunerBuilder , init_worker,
1921 test_help:: {
2022 FakeWfResponses , MockPollCfg , ResponseType , TEST_Q , build_mock_pollers,
2123 drain_pollers_and_shutdown, hist_to_poll_resp, mock_worker, mock_worker_client,
@@ -24,7 +26,11 @@ use temporal_sdk_core::{
2426use temporal_sdk_core_api:: {
2527 Worker ,
2628 errors:: WorkerValidationError ,
27- worker:: { PollerBehavior , WorkerConfigBuilder , WorkerVersioningStrategy } ,
29+ worker:: {
30+ ActivitySlotKind , LocalActivitySlotKind , PollerBehavior , SlotInfo , SlotInfoTrait ,
31+ SlotMarkUsedContext , SlotReleaseContext , SlotReservationContext , SlotSupplier ,
32+ SlotSupplierPermit , WorkerConfigBuilder , WorkerVersioningStrategy , WorkflowSlotKind ,
33+ } ,
2834} ;
2935use temporal_sdk_core_protos:: {
3036 DEFAULT_WORKFLOW_TYPE , TestHistoryBuilder , canned_histories,
@@ -571,3 +577,282 @@ async fn sets_build_id_from_wft_complete() {
571577 . unwrap ( ) ;
572578 worker. run_until_done ( ) . await . unwrap ( ) ;
573579}
580+
581+ #[ derive( Debug , Clone ) ]
582+ enum SlotEvent {
583+ ReserveSlot {
584+ slot_type : & ' static str ,
585+ } ,
586+ TryReserveSlot {
587+ slot_type : & ' static str ,
588+ } ,
589+ MarkSlotUsed {
590+ slot_type : & ' static str ,
591+ is_sticky : bool ,
592+ workflow_type : Option < String > ,
593+ activity_type : Option < String > ,
594+ } ,
595+ ReleaseSlot {
596+ slot_type : & ' static str ,
597+ } ,
598+ }
599+
600+ struct TrackingSlotSupplier < SK > {
601+ events : Arc < Mutex < Vec < SlotEvent > > > ,
602+ slot_type : & ' static str ,
603+ _phantom : std:: marker:: PhantomData < SK > ,
604+ }
605+
606+ impl < SK > TrackingSlotSupplier < SK > {
607+ fn new ( slot_type : & ' static str ) -> Self {
608+ Self {
609+ events : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
610+ slot_type,
611+ _phantom : std:: marker:: PhantomData ,
612+ }
613+ }
614+
615+ fn get_events ( & self ) -> Vec < SlotEvent > {
616+ self . events . lock ( ) . unwrap ( ) . clone ( )
617+ }
618+
619+ fn add_event ( & self , event : SlotEvent ) {
620+ self . events . lock ( ) . unwrap ( ) . push ( event) ;
621+ }
622+
623+ fn extract_slot_info ( info : & dyn SlotInfoTrait ) -> ( bool , Option < String > , Option < String > ) {
624+ match info. downcast ( ) {
625+ SlotInfo :: Workflow ( w) => ( w. is_sticky , Some ( w. workflow_type . clone ( ) ) , None ) ,
626+ SlotInfo :: Activity ( a) => ( false , None , Some ( a. activity_type . clone ( ) ) ) ,
627+ SlotInfo :: LocalActivity ( a) => ( false , None , Some ( a. activity_type . clone ( ) ) ) ,
628+ SlotInfo :: Nexus ( _) => ( false , None , None ) ,
629+ }
630+ }
631+ }
632+
633+ #[ async_trait:: async_trait]
634+ impl < SK > SlotSupplier for TrackingSlotSupplier < SK >
635+ where
636+ SK : temporal_sdk_core_api:: worker:: SlotKind + Send + Sync ,
637+ SK :: Info : SlotInfoTrait ,
638+ {
639+ type SlotKind = SK ;
640+
641+ async fn reserve_slot ( & self , _ctx : & dyn SlotReservationContext ) -> SlotSupplierPermit {
642+ self . add_event ( SlotEvent :: ReserveSlot {
643+ slot_type : self . slot_type ,
644+ } ) ;
645+ SlotSupplierPermit :: with_user_data ( ( ) )
646+ }
647+
648+ fn try_reserve_slot ( & self , _ctx : & dyn SlotReservationContext ) -> Option < SlotSupplierPermit > {
649+ self . add_event ( SlotEvent :: TryReserveSlot {
650+ slot_type : self . slot_type ,
651+ } ) ;
652+ Some ( SlotSupplierPermit :: with_user_data ( ( ) ) )
653+ }
654+
655+ fn mark_slot_used ( & self , ctx : & dyn SlotMarkUsedContext < SlotKind = Self :: SlotKind > ) {
656+ let ( is_sticky, workflow_type, activity_type) = Self :: extract_slot_info ( ctx. info ( ) ) ;
657+ self . add_event ( SlotEvent :: MarkSlotUsed {
658+ slot_type : self . slot_type ,
659+ is_sticky,
660+ workflow_type,
661+ activity_type,
662+ } ) ;
663+ }
664+
665+ fn release_slot ( & self , _ctx : & dyn SlotReleaseContext < SlotKind = Self :: SlotKind > ) {
666+ self . add_event ( SlotEvent :: ReleaseSlot {
667+ slot_type : self . slot_type ,
668+ } ) ;
669+ }
670+ }
671+
672+ #[ tokio:: test]
673+ async fn test_custom_slot_supplier_simple ( ) {
674+ let wf_supplier = Arc :: new ( TrackingSlotSupplier :: < WorkflowSlotKind > :: new ( "workflow" ) ) ;
675+ let activity_supplier = Arc :: new ( TrackingSlotSupplier :: < ActivitySlotKind > :: new ( "activity" ) ) ;
676+ let local_activity_supplier = Arc :: new ( TrackingSlotSupplier :: < LocalActivitySlotKind > :: new (
677+ "local_activity" ,
678+ ) ) ;
679+
680+ let mut starter = CoreWfStarter :: new ( "test_custom_slot_supplier_simple" ) ;
681+ starter. worker_config . clear_max_outstanding_opts ( ) ;
682+
683+ let mut tb = TunerBuilder :: default ( ) ;
684+ tb. workflow_slot_supplier ( wf_supplier. clone ( ) ) ;
685+ tb. activity_slot_supplier ( activity_supplier. clone ( ) ) ;
686+ tb. local_activity_slot_supplier ( local_activity_supplier. clone ( ) ) ;
687+ starter. worker_config . tuner ( Arc :: new ( tb. build ( ) ) ) ;
688+
689+ let mut worker = starter. worker ( ) . await ;
690+
691+ worker. register_activity (
692+ "SlotSupplierActivity" ,
693+ |_: temporal_sdk:: ActContext , _: ( ) | async move { Ok ( ( ) ) } ,
694+ ) ;
695+ worker. register_wf (
696+ "SlotSupplierWorkflow" . to_owned ( ) ,
697+ |ctx : WfContext | async move {
698+ let _result = ctx
699+ . activity ( ActivityOptions {
700+ activity_type : "SlotSupplierActivity" . to_string ( ) ,
701+ start_to_close_timeout : Some ( Duration :: from_secs ( 10 ) ) ,
702+ ..Default :: default ( )
703+ } )
704+ . await ;
705+ let _result = ctx
706+ . local_activity ( LocalActivityOptions {
707+ activity_type : "SlotSupplierActivity" . to_string ( ) ,
708+ start_to_close_timeout : Some ( Duration :: from_secs ( 10 ) ) ,
709+ ..Default :: default ( )
710+ } )
711+ . await ;
712+ Ok ( ( ) . into ( ) )
713+ } ,
714+ ) ;
715+
716+ worker
717+ . submit_wf (
718+ "test-wf" . to_owned ( ) ,
719+ "SlotSupplierWorkflow" . to_owned ( ) ,
720+ vec ! [ ] ,
721+ Default :: default ( ) ,
722+ )
723+ . await
724+ . unwrap ( ) ;
725+
726+ worker. run_until_done ( ) . await . unwrap ( ) ;
727+
728+ // Collect all events
729+ let wf_events = wf_supplier. get_events ( ) ;
730+ let activity_events = activity_supplier. get_events ( ) ;
731+ let local_activity_events = local_activity_supplier. get_events ( ) ;
732+
733+ // Verify workflow slot events - should have reserve, mark used, and release events
734+ assert ! ( wf_events. iter( ) . any(
735+ |e| matches!( e, SlotEvent :: ReserveSlot { slot_type, .. } if * slot_type == "workflow" )
736+ ) ) ;
737+ assert ! ( wf_events. iter( ) . any(
738+ |e| matches!( e, SlotEvent :: MarkSlotUsed { slot_type, .. } if * slot_type == "workflow" )
739+ ) ) ;
740+ assert ! (
741+ wf_events
742+ . iter( )
743+ . any( |e| matches!( e, SlotEvent :: ReleaseSlot { slot_type } if * slot_type == "workflow" ) )
744+ ) ;
745+
746+ // Verify activity slot events - should have reserve, try_reserve (for eager execution), mark
747+ // used, and release
748+ assert ! ( activity_events. iter( ) . any(
749+ |e| matches!( e, SlotEvent :: ReserveSlot { slot_type, .. } if * slot_type == "activity" )
750+ ) ) ;
751+ assert ! (
752+ activity_events. iter( ) . any(
753+ |e| matches!( e, SlotEvent :: TryReserveSlot { slot_type } if * slot_type == "activity" )
754+ )
755+ ) ;
756+ assert ! ( activity_events. iter( ) . any(
757+ |e| matches!( e, SlotEvent :: MarkSlotUsed { slot_type, .. } if * slot_type == "activity" )
758+ ) ) ;
759+ assert ! (
760+ activity_events
761+ . iter( )
762+ . any( |e| matches!( e, SlotEvent :: ReleaseSlot { slot_type } if * slot_type == "activity" ) )
763+ ) ;
764+
765+ // Verify local activity slot events
766+ assert ! ( local_activity_events. iter( ) . any(
767+ |e| matches!( e, SlotEvent :: ReserveSlot { slot_type, .. } if * slot_type == "local_activity" )
768+ ) ) ;
769+ assert ! ( local_activity_events. iter( ) . any(
770+ |e| matches!( e, SlotEvent :: MarkSlotUsed { slot_type, .. } if * slot_type == "local_activity" )
771+ ) ) ;
772+ assert ! ( local_activity_events. iter( ) . any(
773+ |e| matches!( e, SlotEvent :: ReleaseSlot { slot_type } if * slot_type == "local_activity" )
774+ ) ) ;
775+
776+ assert ! (
777+ wf_events
778+ . iter( )
779+ . any( |e| matches!( e, SlotEvent :: MarkSlotUsed {
780+ slot_type: "workflow" ,
781+ workflow_type: Some ( wf_type) ,
782+ ..
783+ } if wf_type == "SlotSupplierWorkflow" ) )
784+ ) ;
785+ assert ! (
786+ activity_events
787+ . iter( )
788+ . any( |e| matches!( e, SlotEvent :: MarkSlotUsed {
789+ slot_type: "activity" ,
790+ activity_type: Some ( act_type) ,
791+ ..
792+ } if act_type == "SlotSupplierActivity" ) )
793+ ) ;
794+ assert ! (
795+ local_activity_events
796+ . iter( )
797+ . any( |e| matches!( e, SlotEvent :: MarkSlotUsed {
798+ slot_type: "local_activity" ,
799+ activity_type: Some ( act_type) ,
800+ ..
801+ } if act_type == "SlotSupplierActivity" ) )
802+ ) ;
803+ assert ! ( wf_events. iter( ) . any( |e| matches!(
804+ e,
805+ SlotEvent :: MarkSlotUsed {
806+ slot_type: "workflow" ,
807+ is_sticky: false ,
808+ ..
809+ }
810+ ) ) ) ;
811+
812+ // Verify that the number of reserve/try_reserve events matches the number of release events
813+ let total_reserves = wf_events
814+ . iter ( )
815+ . filter ( |e| {
816+ matches ! (
817+ e,
818+ SlotEvent :: ReserveSlot { .. } | SlotEvent :: TryReserveSlot { .. }
819+ )
820+ } )
821+ . count ( )
822+ + activity_events
823+ . iter ( )
824+ . filter ( |e| {
825+ matches ! (
826+ e,
827+ SlotEvent :: ReserveSlot { .. } | SlotEvent :: TryReserveSlot { .. }
828+ )
829+ } )
830+ . count ( )
831+ + local_activity_events
832+ . iter ( )
833+ . filter ( |e| {
834+ matches ! (
835+ e,
836+ SlotEvent :: ReserveSlot { .. } | SlotEvent :: TryReserveSlot { .. }
837+ )
838+ } )
839+ . count ( ) ;
840+
841+ let total_releases = wf_events
842+ . iter ( )
843+ . filter ( |e| matches ! ( e, SlotEvent :: ReleaseSlot { .. } ) )
844+ . count ( )
845+ + activity_events
846+ . iter ( )
847+ . filter ( |e| matches ! ( e, SlotEvent :: ReleaseSlot { .. } ) )
848+ . count ( )
849+ + local_activity_events
850+ . iter ( )
851+ . filter ( |e| matches ! ( e, SlotEvent :: ReleaseSlot { .. } ) )
852+ . count ( ) ;
853+
854+ assert_eq ! (
855+ total_reserves, total_releases,
856+ "Number of reserves should equal number of releases"
857+ ) ;
858+ }
0 commit comments