Skip to content

Commit d6f39a8

Browse files
move to idle instead of offline (#40)
* move to idle instead of offline
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* update test
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 75daa56 commit d6f39a8

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

laufband/graphband.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,10 @@ def _register_worker(self):
218218
session.commit()
219219

220220
def __del__(self):
221-
if hasattr(self, "_thread_event"):
221+
if hasattr(self, "_thread_event") and hasattr(self, "_heartbeat_thread"):
222222
self._thread_event.set()
223+
if self._heartbeat_thread.is_alive():
224+
self._heartbeat_thread.join()
223225

224226
def close(self):
225227
"""Exit out of the graphband generator.
@@ -536,5 +538,3 @@ def __iter__(self) -> Iterator[Task[TaskTypeVar]]:
536538
break
537539
if completed_naturally:
538540
self._iterator_completed = True
539-
self._thread_event.set()
540-
self._heartbeat_thread.join()

tests/test_graphband.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_graphband_sequential_success(tmp_path):
9191
with Session(pbar._engine) as session:
9292
workers = session.query(WorkerEntry).all()
9393
assert len(workers) == 1
94-
assert workers[0].status == WorkerStatus.OFFLINE
94+
assert workers[0].status == WorkerStatus.IDLE
9595
tasks = session.query(TaskEntry).all()
9696
assert len(tasks) == 10
9797
for id, task in enumerate(tasks):
@@ -102,6 +102,15 @@ def test_graphband_sequential_success(tmp_path):
102102
# if we now iterate again, we yield nothing
103103
assert list(pbar) == []
104104

105+
# Test that worker goes offline when garbage collected
106+
engine = pbar._engine
107+
del pbar
108+
109+
with Session(engine) as session:
110+
workers = session.query(WorkerEntry).all()
111+
assert len(workers) == 1
112+
assert workers[0].status == WorkerStatus.OFFLINE
113+
105114
pbar = Graphband(
106115
sequential_task(),
107116
db=f"sqlite:///{tmp_path}/graphband.sqlite",
@@ -732,3 +741,78 @@ def test_has_more_jobs_with_killed_workers(tmp_path):
732741
assert retries_worker.has_more_jobs is True
733742
assert len(list(retries_worker)) == 1
734743
assert retries_worker.has_more_jobs is False
744+
745+
746+
def test_resume_worker(tmp_path):
747+
lock_path = f"{tmp_path}/graphband.lock"
748+
db_path = f"sqlite:///{tmp_path}/graphband.sqlite"
749+
750+
worker = Graphband(
751+
sequential_task(),
752+
db=db_path,
753+
lock=Lock(lock_path),
754+
heartbeat_timeout=2,
755+
heartbeat_interval=1,
756+
identifier="worker",
757+
)
758+
engine = create_engine(worker.db)
759+
length = 0
760+
761+
with Session(engine) as session:
762+
worker_entry = session.get(WorkerEntry, "worker")
763+
assert worker_entry is not None
764+
assert worker_entry.status == WorkerStatus.IDLE
765+
766+
for item in worker:
767+
with Session(engine) as session:
768+
worker_entry = session.get(WorkerEntry, "worker")
769+
assert worker_entry is not None
770+
assert worker_entry.status == WorkerStatus.BUSY
771+
length += 1
772+
if item.id == "task_5":
773+
break
774+
assert length == 6
775+
776+
with Session(engine) as session:
777+
worker_entry = session.get(WorkerEntry, "worker")
778+
assert worker_entry is not None
779+
assert worker_entry.status == WorkerStatus.IDLE
780+
781+
for item in worker:
782+
with Session(engine) as session:
783+
worker_entry = session.get(WorkerEntry, "worker")
784+
assert worker_entry is not None
785+
assert worker_entry.status == WorkerStatus.BUSY
786+
length += 1
787+
788+
assert length == 10
789+
790+
with Session(engine) as session:
791+
worker_entry = session.get(WorkerEntry, "worker")
792+
assert worker_entry is not None
793+
assert worker_entry.status == WorkerStatus.IDLE
794+
795+
del worker
796+
797+
with Session(engine) as session:
798+
worker_entry = session.get(WorkerEntry, "worker")
799+
assert worker_entry is not None
800+
assert worker_entry.status == WorkerStatus.OFFLINE
801+
802+
proc = multiprocessing.Process(
803+
target=task_worker,
804+
args=(sequential_task, lock_path, db_path, tmp_path / "test.txt", 0.1),
805+
kwargs={
806+
"heartbeat_timeout": 1,
807+
"heartbeat_interval": 0.5,
808+
"max_killed_retries": 0, # No retries allowed for killed tasks
809+
"identifier": "killed-worker",
810+
},
811+
)
812+
proc.start()
813+
proc.join(timeout=5)
814+
815+
with Session(engine) as session:
816+
worker_entry = session.get(WorkerEntry, "killed-worker")
817+
assert worker_entry is not None
818+
assert worker_entry.status == WorkerStatus.OFFLINE

0 commit comments

Comments
 (0)