@@ -91,7 +91,7 @@ def test_graphband_sequential_success(tmp_path):
9191 with Session (pbar ._engine ) as session :
9292 workers = session .query (WorkerEntry ).all ()
9393 assert len (workers ) == 1
94- assert workers [0 ].status == WorkerStatus .OFFLINE
94+ assert workers [0 ].status == WorkerStatus .IDLE
9595 tasks = session .query (TaskEntry ).all ()
9696 assert len (tasks ) == 10
9797 for id , task in enumerate (tasks ):
@@ -102,6 +102,15 @@ def test_graphband_sequential_success(tmp_path):
102102 # if we no iterate again, we yield nothing
103103 assert list (pbar ) == []
104104
105+ # Test that worker goes offline when garbage collected
106+ engine = pbar ._engine
107+ del pbar
108+
109+ with Session (engine ) as session :
110+ workers = session .query (WorkerEntry ).all ()
111+ assert len (workers ) == 1
112+ assert workers [0 ].status == WorkerStatus .OFFLINE
113+
105114 pbar = Graphband (
106115 sequential_task (),
107116 db = f"sqlite:///{ tmp_path } /graphband.sqlite" ,
@@ -732,3 +741,78 @@ def test_has_more_jobs_with_killed_workers(tmp_path):
732741 assert retries_worker .has_more_jobs is True
733742 assert len (list (retries_worker )) == 1
734743 assert retries_worker .has_more_jobs is False
744+
745+
746+ def test_resume_worker (tmp_path ):
747+ lock_path = f"{ tmp_path } /graphband.lock"
748+ db_path = f"sqlite:///{ tmp_path } /graphband.sqlite"
749+
750+ worker = Graphband (
751+ sequential_task (),
752+ db = db_path ,
753+ lock = Lock (lock_path ),
754+ heartbeat_timeout = 2 ,
755+ heartbeat_interval = 1 ,
756+ identifier = "worker" ,
757+ )
758+ engine = create_engine (worker .db )
759+ length = 0
760+
761+ with Session (engine ) as session :
762+ worker_entry = session .get (WorkerEntry , "worker" )
763+ assert worker_entry is not None
764+ assert worker_entry .status == WorkerStatus .IDLE
765+
766+ for item in worker :
767+ with Session (engine ) as session :
768+ worker_entry = session .get (WorkerEntry , "worker" )
769+ assert worker_entry is not None
770+ assert worker_entry .status == WorkerStatus .BUSY
771+ length += 1
772+ if item .id == "task_5" :
773+ break
774+ assert length == 6
775+
776+ with Session (engine ) as session :
777+ worker_entry = session .get (WorkerEntry , "worker" )
778+ assert worker_entry is not None
779+ assert worker_entry .status == WorkerStatus .IDLE
780+
781+ for item in worker :
782+ with Session (engine ) as session :
783+ worker_entry = session .get (WorkerEntry , "worker" )
784+ assert worker_entry is not None
785+ assert worker_entry .status == WorkerStatus .BUSY
786+ length += 1
787+
788+ assert length == 10
789+
790+ with Session (engine ) as session :
791+ worker_entry = session .get (WorkerEntry , "worker" )
792+ assert worker_entry is not None
793+ assert worker_entry .status == WorkerStatus .IDLE
794+
795+ del worker
796+
797+ with Session (engine ) as session :
798+ worker_entry = session .get (WorkerEntry , "worker" )
799+ assert worker_entry is not None
800+ assert worker_entry .status == WorkerStatus .OFFLINE
801+
802+ proc = multiprocessing .Process (
803+ target = task_worker ,
804+ args = (sequential_task , lock_path , db_path , tmp_path / "test.txt" , 0.1 ),
805+ kwargs = {
806+ "heartbeat_timeout" : 1 ,
807+ "heartbeat_interval" : 0.5 ,
808+ "max_killed_retries" : 0 , # No retries allowed for killed tasks
809+ "identifier" : "killed-worker" ,
810+ },
811+ )
812+ proc .start ()
813+ proc .join (timeout = 5 )
814+
815+ with Session (engine ) as session :
816+ worker_entry = session .get (WorkerEntry , "killed-worker" )
817+ assert worker_entry is not None
818+ assert worker_entry .status == WorkerStatus .OFFLINE
0 commit comments