Commit 17bf10f

mihow and claude authored
fix: PSv2 follow-up fixes from integration tests (RolnickLab#1135)
* fix: prevent NATS connection flooding and stale job task fetching

  - Add connect_timeout=5, allow_reconnect=False to NATS connections to prevent leaked reconnection loops from blocking Django's event loop
  - Guard /tasks endpoint against terminal-status jobs (return empty tasks instead of attempting NATS reserve)
  - IncompleteJobFilter now excludes jobs by top-level status in addition to progress JSON stages
  - Add stale worker cleanup to integration test script

  Found during PSv2 integration testing, where stale ADC workers with default DataLoader parallelism overwhelmed the single uvicorn worker thread by flooding /tasks with concurrent NATS reserve requests.

* docs: PSv2 integration test session notes and NATS flooding findings

  Session notes from the 2026-02-16 integration test, including root-cause analysis of stale worker task competition and NATS connection issues. The findings doc tracks applied fixes and remaining TODOs with priorities.

* docs: update session notes with successful test run #3

  PSv2 integration test passed end-to-end (job 1380, 20/20 images). Identified ack_wait=300s as the cause of ~5 min of idle time when GPU processes race for NATS tasks.

* fix: batch NATS task fetch to prevent HTTP timeouts

  Replace N×1 reserve_task() calls with a single reserve_tasks() batch fetch. The previous implementation created a new pull subscription per message (320 NATS round trips for batch=64), causing the /tasks endpoint to exceed HTTP client timeouts. The new approach uses one psub.fetch() call for the entire batch.

* docs: add next session prompt

* feat: add pipeline__slug__in filter for multi-pipeline job queries

  Workers that handle multiple pipelines can now fetch jobs for all of them in a single request: ?pipeline__slug__in=slug1,slug2

* chore: remove local-only docs and scripts from branch

  These files are session notes, planning docs, and test scripts that should stay local rather than be part of the PR.

* feat: set job dispatch_mode at creation time based on project feature flags

  ML jobs with a pipeline now get dispatch_mode set during setup() instead of waiting until run() is called by the Celery worker. This lets the UI show the correct mode immediately after job creation.

* fix: add timeouts to all JetStream operations and restore reconnect policy

  Add NATS_JETSTREAM_TIMEOUT (10s) to all JetStream metadata operations via asyncio.wait_for() so a hung NATS connection fails fast instead of blocking the caller's thread indefinitely. Also restore the intended reconnect policy (2 attempts, 1s wait) that was lost in a prior force push.

* fix: propagate NATS timeouts as 503 instead of swallowing them

  asyncio.TimeoutError from _ensure_stream() and _ensure_consumer() was caught by the broad `except Exception` in reserve_tasks(), silently returning [] and making NATS outages indistinguishable from empty queues. Workers would then poll again immediately, recreating the flooding problem.

  - Add explicit `except asyncio.TimeoutError: raise` in reserve_tasks()
  - Catch TimeoutError and OSError in the /tasks view, return 503
  - Restore allow_reconnect=False (fail fast on connection issues)
  - Add return type annotation to get_connection()

* fix: address review comments (log level, fetch timeout, docstring)

  - Downgrade reserve_tasks log to DEBUG when zero tasks are reserved (avoid log spam from frequent polling)
  - Pass timeout=0.5 from the /tasks endpoint to avoid blocking the worker for 5s on empty queues
  - Fix docstring examples that used the string 'job123' for the int-typed job_id

* fix: catch nats.errors.Error in /tasks endpoint for proper 503 responses

  NoServersError, ConnectionClosedError, and other NATS exceptions inherit from nats.errors.Error (not OSError), so they escaped the handler and returned 500 instead of 503.

Co-authored-by: Claude <noreply@anthropic.com>
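The timeout and 503 changes above follow one pattern: wrap each queue operation in asyncio.wait_for() so a hung connection raises instead of blocking, and translate that failure into a 503 at the HTTP boundary so workers can back off. A minimal stdlib-only sketch of that pattern; the names `jetstream_op`, `fetch_tasks`, and `tasks_view` are illustrative stand-ins, not functions from this repo:

```python
import asyncio

# Illustrative value; the PR uses NATS_JETSTREAM_TIMEOUT = 10s.
JETSTREAM_TIMEOUT = 0.05


async def jetstream_op(delay: float) -> list:
    # Stand-in for a JetStream call that may hang on a dead connection.
    await asyncio.sleep(delay)
    return ["task-1"]


async def fetch_tasks(delay: float) -> list:
    # Fail fast instead of blocking the caller's thread indefinitely.
    return await asyncio.wait_for(jetstream_op(delay), timeout=JETSTREAM_TIMEOUT)


def tasks_view(delay: float) -> tuple:
    # Translate queue outages into 503 at the HTTP boundary, so an outage
    # is distinguishable from an empty queue (which returns 200 + []).
    try:
        tasks = asyncio.run(fetch_tasks(delay))
    except (asyncio.TimeoutError, OSError):
        return 503, {"error": "Task queue temporarily unavailable"}
    return 200, {"tasks": tasks}


print(tasks_view(0.0))  # fast op: (200, {'tasks': ['task-1']})
print(tasks_view(1.0))  # hung op: (503, {'error': 'Task queue temporarily unavailable'})
```

The key property is the last one the commit message calls out: a timeout propagates as an error rather than an empty result, so callers never mistake an outage for an empty queue.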
1 parent cf3a59e commit 17bf10f

File tree: 5 files changed (+220, −94 lines)


ami/jobs/models.py (10 additions, 5 deletions)

@@ -461,9 +461,7 @@ def run(cls, job: "Job"):
         # End image collection stage
         job.save()

-        if job.project.feature_flags.async_pipeline_workers:
-            job.dispatch_mode = JobDispatchMode.ASYNC_API
-            job.save(update_fields=["dispatch_mode"])
+        if job.dispatch_mode == JobDispatchMode.ASYNC_API:
             queued = queue_images_to_nats(job, images)
             if not queued:
                 job.logger.error("Aborting job %s because images could not be queued to NATS", job.pk)
@@ -473,8 +471,6 @@ def run(cls, job: "Job"):
                 job.save()
                 return
         else:
-            job.dispatch_mode = JobDispatchMode.SYNC_API
-            job.save(update_fields=["dispatch_mode"])
             cls.process_images(job, images)

     @classmethod
@@ -919,6 +915,15 @@ def setup(self, save=True):
             self.progress.add_stage_param(delay_stage.key, "Mood", "😴")

         if self.pipeline:
+            # Set dispatch mode based on project feature flags at creation time
+            # so the UI can show the correct mode before the job runs.
+            # Only override if still at the default (INTERNAL), to allow explicit overrides.
+            if self.dispatch_mode == JobDispatchMode.INTERNAL:
+                if self.project and self.project.feature_flags.async_pipeline_workers:
+                    self.dispatch_mode = JobDispatchMode.ASYNC_API
+                else:
+                    self.dispatch_mode = JobDispatchMode.SYNC_API
+
             collect_stage = self.progress.add_stage("Collect")
             self.progress.add_stage_param(collect_stage.key, "Total Images", "")
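The setup() change above is a small decision table: only a pipeline job still at the default mode gets resolved from the project's feature flag. A standalone sketch of just that logic, with the enum values mirroring JobDispatchMode; `resolve_dispatch_mode` is a hypothetical helper for illustration, not a function in the repo:

```python
from enum import Enum


class DispatchMode(str, Enum):
    INTERNAL = "internal"
    SYNC_API = "sync_api"
    ASYNC_API = "async_api"


def resolve_dispatch_mode(
    current: DispatchMode, has_pipeline: bool, async_flag: bool
) -> DispatchMode:
    """Resolve a job's dispatch mode at creation time."""
    # Non-pipeline jobs stay as-is; explicit overrides (anything other
    # than the INTERNAL default) also survive untouched.
    if not has_pipeline or current != DispatchMode.INTERNAL:
        return current
    # Pipeline jobs resolve from the project feature flag.
    return DispatchMode.ASYNC_API if async_flag else DispatchMode.SYNC_API
```

This matches the three cases the new test exercises: SYNC_API without the flag, ASYNC_API with it, and INTERNAL for jobs without a pipeline.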

ami/jobs/tests.py (65 additions, 4 deletions)

@@ -384,6 +384,36 @@ def test_filter_by_pipeline_slug(self):
         self.assertEqual(data["count"], 1)
         self.assertEqual(data["results"][0]["id"], job_with_pipeline.pk)

+    def test_filter_by_pipeline_slug_in(self):
+        """Test filtering jobs by pipeline__slug__in (multiple slugs)."""
+        pipeline_a = self._create_pipeline("Pipeline A", "pipeline-a")
+        pipeline_b = Pipeline.objects.create(name="Pipeline B", slug="pipeline-b", description="B")
+        pipeline_b.projects.add(self.project)
+        pipeline_c = Pipeline.objects.create(name="Pipeline C", slug="pipeline-c", description="C")
+        pipeline_c.projects.add(self.project)
+
+        job_a = self._create_ml_job("Job A", pipeline_a)
+        job_b = self._create_ml_job("Job B", pipeline_b)
+        job_c = self._create_ml_job("Job C", pipeline_c)
+
+        self.client.force_authenticate(user=self.user)
+
+        # Filter for two of the three pipelines
+        jobs_list_url = reverse_with_params(
+            "api:job-list",
+            params={"project_id": self.project.pk, "pipeline__slug__in": "pipeline-a,pipeline-b"},
+        )
+        resp = self.client.get(jobs_list_url)
+
+        self.assertEqual(resp.status_code, 200)
+        data = resp.json()
+        returned_ids = {job["id"] for job in data["results"]}
+        self.assertIn(job_a.pk, returned_ids)
+        self.assertIn(job_b.pk, returned_ids)
+        self.assertNotIn(job_c.pk, returned_ids)
+        # Original setUp job (no pipeline) should also be excluded
+        self.assertNotIn(self.job.pk, returned_ids)
+
     def test_search_jobs(self):
         """Test searching jobs by name and pipeline name."""
         pipeline = self._create_pipeline("SearchablePipeline", "searchable-pipeline")
@@ -571,13 +601,11 @@ def test_dispatch_mode_filtering(self):
             dispatch_mode=JobDispatchMode.ASYNC_API,
         )

-        # Create a job with default dispatch_mode (should be "internal")
+        # Create a non-ML job without a pipeline (dispatch_mode stays "internal")
         internal_job = Job.objects.create(
-            job_type_key=MLJob.key,
+            job_type_key="data_storage_sync",
             project=self.project,
             name="Internal Job",
-            pipeline=self.pipeline,
-            source_image_collection=self.source_image_collection,
         )

         self.client.force_authenticate(user=self.user)
@@ -614,6 +642,39 @@ def test_dispatch_mode_filtering(self):
         expected_ids = {sync_job.pk, async_job.pk, internal_job.pk}
         self.assertEqual(returned_ids, expected_ids)

+    def test_ml_job_dispatch_mode_set_on_creation(self):
+        """Test that ML jobs get dispatch_mode set based on project feature flags at creation time."""
+        # Without async flag, ML job should default to sync_api
+        sync_job = Job.objects.create(
+            job_type_key=MLJob.key,
+            project=self.project,
+            name="Auto Sync Job",
+            pipeline=self.pipeline,
+            source_image_collection=self.source_image_collection,
+        )
+        self.assertEqual(sync_job.dispatch_mode, JobDispatchMode.SYNC_API)
+
+        # Enable async flag on project
+        self.project.feature_flags.async_pipeline_workers = True
+        self.project.save()
+
+        async_job = Job.objects.create(
+            job_type_key=MLJob.key,
+            project=self.project,
+            name="Auto Async Job",
+            pipeline=self.pipeline,
+            source_image_collection=self.source_image_collection,
+        )
+        self.assertEqual(async_job.dispatch_mode, JobDispatchMode.ASYNC_API)
+
+        # Non-pipeline job should stay internal regardless of feature flag
+        internal_job = Job.objects.create(
+            job_type_key="data_storage_sync",
+            project=self.project,
+            name="Internal Job",
+        )
+        self.assertEqual(internal_job.dispatch_mode, JobDispatchMode.INTERNAL)
+
     def test_tasks_endpoint_rejects_non_async_jobs(self):
         """Test that /tasks endpoint returns 400 for non-async_api jobs."""
         from ami.base.serializers import reverse_with_params

ami/jobs/views.py (18 additions, 12 deletions)

@@ -1,5 +1,7 @@
+import asyncio
 import logging

+import nats.errors
 import pydantic
 from asgiref.sync import async_to_sync
 from django.db.models import Q
@@ -32,6 +34,7 @@ class JobFilterSet(filters.FilterSet):
     """Custom filterset to enable pipeline name filtering."""

     pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact")
+    pipeline__slug__in = filters.BaseInFilter(field_name="pipeline__slug", lookup_expr="in")

     class Meta:
         model = Job
@@ -55,11 +58,12 @@ def filter_queryset(self, request, queryset, view):
         incomplete_only = url_boolean_param(request, "incomplete_only", default=False)
         # Filter to incomplete jobs if requested (checks "results" stage status)
         if incomplete_only:
-            # Create filters for each final state to exclude
+            # Exclude jobs with a terminal top-level status
+            queryset = queryset.exclude(status__in=JobState.final_states())
+
+            # Also exclude jobs where the "results" stage has a final state status
             final_states = JobState.final_states()
             exclude_conditions = Q()
-
-            # Exclude jobs where the "results" stage has a final state status
             for state in final_states:
                 # JSON path query to check if results stage status is in final states
                 # @TODO move to a QuerySet method on Job model if/when this needs to be reused elsewhere
@@ -233,6 +237,10 @@ def tasks(self, request, pk=None):
         if job.dispatch_mode != JobDispatchMode.ASYNC_API:
             raise ValidationError("Only async_api jobs have fetchable tasks")

+        # Don't fetch tasks from completed/failed/revoked jobs
+        if job.status in JobState.final_states():
+            return Response({"tasks": []})
+
         # Validate that the job has a pipeline
         if not job.pipeline:
             raise ValidationError("This job does not have a pipeline configured")
@@ -241,16 +249,14 @@ def tasks(self, request, pk=None):
         from ami.ml.orchestration.nats_queue import TaskQueueManager

         async def get_tasks():
-            tasks = []
             async with TaskQueueManager() as manager:
-                for _ in range(batch):
-                    task = await manager.reserve_task(job.pk, timeout=0.1)
-                    if task:
-                        tasks.append(task.dict())
-            return tasks
-
-        # Use async_to_sync to properly handle the async call
-        tasks = async_to_sync(get_tasks)()
+                return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch, timeout=0.5)]
+
+        try:
+            tasks = async_to_sync(get_tasks)()
+        except (asyncio.TimeoutError, OSError, nats.errors.Error) as e:
+            logger.warning("NATS unavailable while fetching tasks for job %s: %s", job.pk, e)
+            return Response({"error": "Task queue temporarily unavailable"}, status=503)

         return Response({"tasks": tasks})
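The new pipeline__slug__in filter relies on django-filter's BaseInFilter, which splits the raw query value on commas and applies an `in` lookup. A pure-Python sketch of those matching semantics (no Django; `parse_in_param` and `filter_jobs_by_slugs` are illustrative names, not repo code):

```python
def parse_in_param(raw: str) -> list:
    # BaseInFilter-style parsing of a comma-separated value, e.g.
    # ?pipeline__slug__in=pipeline-a,pipeline-b
    return [v for v in raw.split(",") if v]


def filter_jobs_by_slugs(jobs: list, raw: str) -> list:
    slugs = set(parse_in_param(raw))
    # Jobs without a pipeline (slug None) never match, mirroring SQL IN semantics.
    return [j for j in jobs if j.get("pipeline_slug") in slugs]


jobs = [
    {"id": 1, "pipeline_slug": "pipeline-a"},
    {"id": 2, "pipeline_slug": "pipeline-b"},
    {"id": 3, "pipeline_slug": "pipeline-c"},
    {"id": 4, "pipeline_slug": None},
]
matched = filter_jobs_by_slugs(jobs, "pipeline-a,pipeline-b")
print([j["id"] for j in matched])  # → [1, 2]
```

This is the behavior the new test_filter_by_pipeline_slug_in test asserts against the real endpoint: the two requested pipelines match, and both the third pipeline and the pipeline-less setUp job are excluded.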
