
Commit 75daa56

add Graphband.has_more_jobs (#38)
* add simple test case
* add additional test case
* test killed (apparently)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* lint
* test killed retries
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 83d9fe2 commit 75daa56

2 files changed: +321, -0 lines changed


laufband/graphband.py

Lines changed: 114 additions & 0 deletions
@@ -161,6 +161,7 @@ def __init__(
         # here we keep track of failed job data to be retried later
         self._failed_job_cache = {}
         self._iterator = None
+        self._iterator_completed = False
         self._heartbeat_timeout = heartbeat_timeout
         self._heartbeat_interval = heartbeat_interval
         self._labels = frozenset(labels or [])
@@ -257,6 +258,114 @@ def db(self) -> str:
         """Return the database URL."""
         return self._db
 
+    @property
+    def has_more_jobs(self) -> bool:
+        """Check if there are any more jobs that this worker could process.
+
+        Returns
+        -------
+        bool
+            True if there are jobs that this worker could potentially process,
+            False if all jobs are either completed or permanently failed/killed.
+
+        Notes
+        -----
+        Simple logic: Check if there are processable jobs in the failed cache
+        or incomplete tasks in the database that match our labels and haven't
+        exceeded retry limits.
+
+        Jobs with label mismatches are ignored since this worker would never
+        pick them up.
+
+        Examples
+        --------
+        >>> pbar = Graphband(tasks, labels={'worker-a'})
+        >>> list(pbar)  # Process all available jobs
+        >>> pbar.has_more_jobs  # Should be False if no more jobs for this worker
+        False
+        """
+        if self.disabled:
+            return False
+
+        # Simple logic - check failed cache and database
+        retryable_failed_jobs = 0
+        incomplete_jobs = 0
+
+        with self.lock:
+            with Session(self._engine) as session:
+                # Check failed job cache
+                for task in self._failed_job_cache.values():
+                    if not task.requirements.issubset(self.labels):
+                        continue
+
+                    task_entry = session.get(TaskEntry, task.id)
+                    if task_entry is None:
+                        retryable_failed_jobs += 1  # New task, can be processed
+                    elif (
+                        task_entry.failed_retries < self._max_failed_retries
+                        and task_entry.killed_retries < self._max_killed_retries
+                    ):
+                        retryable_failed_jobs += 1
+
+                # Check database for incomplete tasks
+                workflow = (
+                    session.query(WorkflowEntry)
+                    .filter(WorkflowEntry.id == "main")
+                    .first()
+                )
+                if workflow is not None:
+                    for task_entry in workflow.tasks:
+                        # Skip if labels don't match
+                        if not set(task_entry.requirements).issubset(self.labels):
+                            continue
+
+                        # Count incomplete tasks that can be retried
+                        if (
+                            not task_entry.completed
+                            and task_entry.failed_retries < self._max_failed_retries
+                            and task_entry.killed_retries < self._max_killed_retries
+                        ):
+                            incomplete_jobs += 1
+
+                    # If no incomplete jobs found but iterator hasn't completed,
+                    # check if we have any tasks matching our labels
+                    if incomplete_jobs == 0 and not self._iterator_completed:
+                        # Check if all tasks matching our labels are complete
+                        total_matching_tasks = len(
+                            [
+                                t
+                                for t in workflow.tasks
+                                if set(t.requirements).issubset(self.labels)
+                            ]
+                        )
+                        completed_matching_tasks = len(
+                            [
+                                t
+                                for t in workflow.tasks
+                                if set(t.requirements).issubset(self.labels)
+                                and t.completed
+                            ]
+                        )
+
+                        # If we have tasks in DB and all matching ones are complete,
+                        # and no retryable failed jobs, then no more jobs
+                        if (
+                            total_matching_tasks > 0
+                            and completed_matching_tasks == total_matching_tasks
+                            and retryable_failed_jobs == 0
+                        ):
+                            return False
+
+                        # Otherwise, there might be more tasks to process
+                        return True
+                else:
+                    # No workflow exists yet - if we haven't completed iteration,
+                    # assume there are jobs from the original graph
+                    if not self._iterator_completed:
+                        return True
+
+        return (retryable_failed_jobs + incomplete_jobs) > 0
+
     def __iter__(self) -> Iterator[Task[TaskTypeVar]]:
         """The generator that handles the iteration logic."""

@@ -278,6 +387,7 @@ def __iter__(self) -> Iterator[Task[TaskTypeVar]]:
             yield from self._iterator
             return
 
+        completed_naturally = True
         for task in self._iterator:
             if not task.requirements.issubset(self.labels):
                 log.debug(
@@ -389,6 +499,7 @@ def __iter__(self) -> Iterator[Task[TaskTypeVar]]:
 
                     worker.status = WorkerStatus.IDLE
                     session.commit()
+                    completed_naturally = False
                     break
             with self.lock:
                 with Session(self._engine) as session:
@@ -421,6 +532,9 @@ def __iter__(self) -> Iterator[Task[TaskTypeVar]]:
                     worker.status = WorkerStatus.IDLE
                     session.commit()
             if self._close_trigger:
+                completed_naturally = False
                 break
+        if completed_naturally:
+            self._iterator_completed = True
         self._thread_event.set()
         self._heartbeat_thread.join()
tests/test_graphband.py

Lines changed: 207 additions & 0 deletions
@@ -525,3 +525,210 @@ def test_multi_dependency_graph_task(tmp_path):
         assert sorted(
             [e.id for e in entries[task_id].current_status.dependencies]
         ) == sorted(deps)
+
+
+def test_has_more_jobs(tmp_path):
+    worker = Graphband(
+        sequential_task(),
+        db=f"sqlite:///{tmp_path}/graphband.sqlite",
+        lock=Lock(f"{tmp_path}/graphband.lock"),
+    )
+    assert worker.has_more_jobs is True
+    for item in worker:
+        if item.id == "task_5":
+            break
+    # there are still 4 remaining tasks
+    assert worker.has_more_jobs is True
+
+    assert len(list(worker)) == 4
+    # won't pick up the failed job
+    assert worker.has_more_jobs is False
+
+    w2 = Graphband(
+        sequential_task(),
+        db=f"sqlite:///{tmp_path}/graphband.sqlite",
+        lock=Lock(f"{tmp_path}/graphband.lock"),
+        max_failed_retries=2,
+        identifier="retry-worker",
+    )
+    # will pick up the failed job
+    assert w2.has_more_jobs is True
+    assert len(list(w2)) == 1
+    assert w2.has_more_jobs is False
+
+    w3 = Graphband(
+        sequential_task(),
+        db=f"sqlite:///{tmp_path}/graphband.sqlite",
+        lock=Lock(f"{tmp_path}/graphband.lock"),
+        max_failed_retries=2,
+        identifier="retry-worker-2",
+    )
+    assert w3.has_more_jobs is False  # all jobs are now completed successfully
+
+
+def blocked_dependency_graph_task():
+    """Create a graph where task b blocks task c due to label requirements.
+
+    Graph structure:
+        a --> b --> c
+    Where b requires 'special-worker' label, but a and c require 'main' label.
+    """
+    digraph = nx.DiGraph()
+    edges = [
+        ("a", "b"),
+        ("b", "c"),
+    ]
+    digraph.add_edges_from(edges)
+    digraph.nodes["b"]["requirements"] = {"special-worker"}
+    for node in nx.topological_sort(digraph):
+        yield Task(
+            id=node,
+            data=node,
+            dependencies=set(digraph.predecessors(node)),
+            requirements=digraph.nodes[node].get("requirements", {"main"}),
+        )
+
+
+def test_has_more_jobs_with_blocked_dependencies(tmp_path):
+    """Test has_more_jobs when dependencies are blocked by label mismatches.
+
+    This test verifies the scenario where:
+    - a --> b --> c dependency chain
+    - b needs a different label than a, c
+    - has_more_jobs for the a/c worker should be true until c is completed
+    - but the worker won't be able to pick up c because b isn't completed
+    """
+    # Worker that can process 'main' tasks (a and c) but not 'special-worker' tasks (b)
+    main_worker = Graphband(
+        blocked_dependency_graph_task(),
+        db=f"sqlite:///{tmp_path}/graphband.sqlite",
+        lock=Lock(f"{tmp_path}/graphband.lock"),
+        labels={"main"},
+        identifier="main-worker",
+    )
+
+    # Initially should have jobs available
+    assert main_worker.has_more_jobs is True
+
+    # Process available tasks - should only get task 'a'
+    items = list(main_worker)
+    assert len(items) == 1
+    assert items[0].id == "a"
+
+    # After processing 'a', should still have more jobs (task 'c' exists but is blocked)
+    # This is the key behavior: has_more_jobs returns True even though this worker
+    # cannot make progress because 'c' depends on 'b' which requires different labels
+    assert main_worker.has_more_jobs is True
+
+    # Trying to iterate again should yield nothing since 'c' is blocked by 'b'
+    items_second = list(main_worker)
+    assert len(items_second) == 0
+
+    # Should still report more jobs available (task 'c' is incomplete)
+    assert main_worker.has_more_jobs is True
+
+    # Now create a worker that can process the 'special-worker' task 'b'
+    special_worker = Graphband(
+        blocked_dependency_graph_task(),
+        db=f"sqlite:///{tmp_path}/graphband.sqlite",
+        lock=Lock(f"{tmp_path}/graphband.lock"),
+        labels={"special-worker"},
+        identifier="special-worker",
+    )
+
+    # Special worker should have jobs (task 'b')
+    assert special_worker.has_more_jobs is True
+
+    # Process task 'b'
+    items_special = list(special_worker)
+    assert len(items_special) == 1
+    assert items_special[0].id == "b"
+
+    # After 'b' is complete, special worker should have no more jobs
+    assert special_worker.has_more_jobs is False
+
+    # Now main worker should be able to process task 'c'
+    assert main_worker.has_more_jobs is True
+
+    items_final = list(main_worker)
+    assert len(items_final) == 1
+    assert items_final[0].id == "c"
+
+    # Finally, no more jobs for either worker
+    assert main_worker.has_more_jobs is False
+    assert special_worker.has_more_jobs is False
+
+
+def test_has_more_jobs_with_killed_workers(tmp_path):
+    """Test has_more_jobs behavior when workers are killed
+    and tasks exceed retry limits."""
+    # Test case where killed tasks cannot be retried (max_killed_retries=0)
+    lock_path = f"{tmp_path}/graphband.lock"
+    db = f"sqlite:///{tmp_path}/graphband.sqlite"
+    file = f"{tmp_path}/output.txt"
+
+    # Start a worker that will be killed with no retry allowance
+    proc = multiprocessing.Process(
+        target=task_worker,
+        args=(sequential_task, lock_path, db, file, 3),
+        kwargs={
+            "heartbeat_timeout": 1,
+            "heartbeat_interval": 0.5,
+            "max_killed_retries": 0,  # No retries allowed for killed tasks
+            "identifier": "killed-worker",
+        },
+    )
+    proc.start()
+    time.sleep(2)  # Let the worker start one task
+    proc.kill()
+    proc.join()
+
+    time.sleep(2)
+    # need to start another worker to mark the job as killed in the db
+    _ = Graphband(
+        sequential_task(),
+        db=db,
+        lock=Lock(lock_path),
+        heartbeat_timeout=2,
+        heartbeat_interval=1,
+        identifier="update-worker",
+    )
+    time.sleep(2)
+
+    # Verify the killed task is permanently blocked
+    engine = create_engine(db)
+    with Session(engine) as session:
+        tasks = session.query(TaskEntry).all()
+        # Verify we have exactly one killed task and 9 completed tasks
+        killed_tasks = [
+            t for t in tasks if t.current_status.status == TaskStatusEnum.KILLED
+        ]
+        assert len(killed_tasks) == 1
+
+    no_retries_worker = Graphband(
+        sequential_task(),
+        db=db,
+        lock=Lock(lock_path),
+        heartbeat_timeout=2,
+        heartbeat_interval=1,
+        max_killed_retries=0,  # No retries allowed
+        identifier="no-retries-worker",
+    )
+
+    retries_worker = Graphband(
+        sequential_task(),
+        db=db,
+        lock=Lock(lock_path),
+        heartbeat_timeout=2,
+        heartbeat_interval=1,
+        max_killed_retries=2,  # Allow retries
+        identifier="retries-worker",
+    )
+
+    assert no_retries_worker.has_more_jobs is True
+    assert retries_worker.has_more_jobs is True
+    assert len(list(no_retries_worker)) == 9
+    assert no_retries_worker.has_more_jobs is False
+    assert retries_worker.has_more_jobs is True
+    assert len(list(retries_worker)) == 1
+    assert retries_worker.has_more_jobs is False
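
The blocked-dependency test above relies on has_more_jobs staying True for a worker whose remaining tasks are gated behind another worker's labels. A minimal sketch of that coordination pattern (illustrative only, not part of this commit; import paths are assumed) could alternate two single-label workers, reusing blocked_dependency_graph_task() from the test file above, until neither reports outstanding work.

# Illustrative sketch only: two single-label workers cooperating on the a -> b -> c
# graph from blocked_dependency_graph_task(). Import paths are assumed.
from flufl.lock import Lock
from laufband.graphband import Graphband


def run_both(tmp_path: str) -> None:
    db = f"sqlite:///{tmp_path}/graphband.sqlite"
    main = Graphband(
        blocked_dependency_graph_task(),
        db=db,
        lock=Lock(f"{tmp_path}/graphband.lock"),
        labels={"main"},
        identifier="main-worker",
    )
    special = Graphband(
        blocked_dependency_graph_task(),
        db=db,
        lock=Lock(f"{tmp_path}/graphband.lock"),
        labels={"special-worker"},
        identifier="special-worker",
    )
    # Alternate the two workers until neither reports outstanding work:
    # main claims 'a', special unblocks 'b', then main can finish 'c'.
    while main.has_more_jobs or special.has_more_jobs:
        for task in main:
            pass  # real task handling would go here
        for task in special:
            pass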
