Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 52 additions & 25 deletions src/hope_dedup_engine/apps/api/deduplication/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
from django.utils import timezone

import sentry_sdk
from celery import chord, shared_task, group
from celery import shared_task

from hope_dedup_engine.apps.api.deduplication.config import DeduplicationSetConfig
from hope_dedup_engine.apps.api.models import MainJob, DeduplicationSet
from hope_dedup_engine.apps.api.models.deduplication import DeduplicationSetGroup
from hope_dedup_engine.apps.api.models.jobs import EncodeChunkJob, DeduplicateDatasetJob
from hope_dedup_engine.apps.api.utils.notification import send_notification

from hope_dedup_engine.apps.faces.celery_tasks import (
get_chunks,
finish_with_error,
finish_with_success,
ChunkPurpose,
)
from hope_dedup_engine.apps.faces.services.facial import dedupe_all, encode_faces

HOUR = 60 * 60
RESCHEDULE_INTERVAL = 6 * HOUR
STALE_PROCESSING_THRESHOLD = 24 * HOUR


def finish_processing(ds: DeduplicationSet, error: Exception | None = None) -> None:
    """Record the terminal state of *ds* and send the completion notification.

    When *error* is provided the set is moved to FAILED with that error
    attached; otherwise it is moved to READY. The notification is sent
    in both cases.
    """
    # Pick the set_state(...) arguments up front, then apply them once.
    state_args = (
        (DeduplicationSet.State.FAILED, error)
        if error
        else (DeduplicationSet.State.READY,)
    )
    ds.set_state(*state_args)
    send_notification(ds)


def try_acquire_processing_lock(deduplication_set: DeduplicationSet) -> DeduplicationSet | None:
with transaction.atomic():
DeduplicationSetGroup.objects.select_for_update().get(pk=deduplication_set.group_id)
Expand All @@ -44,8 +46,17 @@ def try_acquire_processing_lock(deduplication_set: DeduplicationSet) -> Deduplic
return deduplication_set


@shared_task(bind=True, soft_time_limit=0.5 * HOUR, time_limit=1 * HOUR)
@shared_task(bind=True)
def find_duplicates(self, dedup_job_id: int, version: int) -> dict[str, Any]:
"""
Process a deduplication job: encode faces and find duplicates.

This task handles the complete deduplication workflow:
1. Acquires a processing lock on the deduplication set
2. Encodes all images without embeddings
3. Runs deduplication (unless encode_only is True)
4. Updates state and sends notification
"""
main_job: MainJob = MainJob.objects.get(pk=dedup_job_id, version=version)

deduplication_set = try_acquire_processing_lock(main_job.deduplication_set)
Expand All @@ -63,25 +74,41 @@ def find_duplicates(self, dedup_job_id: int, version: int) -> dict[str, Any]:

try:
send_notification(deduplication_set)

encoding_ids = deduplication_set.encodings_without_embeddings().values_list("id", flat=True)
chunks = get_chunks(encoding_ids, purpose=ChunkPurpose.ENCODE)
tasks = [
EncodeChunkJob.objects.create(deduplication_set=deduplication_set, encoding_ids=chunk).s()
for chunk in chunks
]
if main_job.encode_only:
chord_id = group(tasks)()
finish_with_success(deduplication_set)
else:
chord_id = chord(tasks)(DeduplicateDatasetJob.objects.create(deduplication_set_id=deduplication_set.pk).s())
config = DeduplicationSetConfig.from_deduplication_set(deduplication_set)

# Encode all images without embeddings
encoding_ids = list(deduplication_set.encodings_without_embeddings().values_list("id", flat=True))
encodings_count = len(encoding_ids)

if encoding_ids:
encode_faces(
deduplication_set,
encoding_ids,
config.face_confidence_threshold,
config.face_coverage_threshold,
config.deduplicate.model_name,
config.deduplicate.detector_backend,
align=config.deduplicate.align,
)

# Run deduplication unless encode_only
findings_count = 0
if not main_job.encode_only:
findings_count = dedupe_all(
deduplication_set=deduplication_set,
duplicate_confidence_threshold=config.duplicate_confidence_threshold,
model_name=config.deduplicate.model_name,
distance_metric=config.deduplicate.distance_metric,
)

finish_processing(deduplication_set)

return {
"deduplication_set": str(deduplication_set),
"chord_id": str(chord_id),
"chunks": len(chunks),
"encodings_processed": encodings_count,
"findings_created": findings_count,
}
except Exception as e:
finish_with_error(deduplication_set, e)
finish_processing(deduplication_set, e)
sentry_sdk.capture_exception(e)
raise
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from django.db import migrations


class Migration(migrations.Migration):
    """Drop the intermediate per-chunk Celery job models.

    The corresponding model classes were removed from the codebase
    (processing no longer creates these job records), so this migration
    deletes their database tables.
    """

    dependencies = [
        ("api", "0061_remove_encoding_api_encodin_dedupli_909c41_idx"),
    ]

    operations = [
        # Each DeleteModel drops the table backing a now-removed job model.
        migrations.DeleteModel(
            name="EncodeChunkJob",
        ),
        migrations.DeleteModel(
            name="DeduplicateDatasetJob",
        ),
        migrations.DeleteModel(
            name="CallbackFindingsJob",
        ),
        migrations.DeleteModel(
            name="DedupeChunkJob",
        ),
    ]
22 changes: 0 additions & 22 deletions src/hope_dedup_engine/apps/api/models/jobs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from django.contrib.postgres.fields import ArrayField
from django.db import models

from django_celery_boost.models import CeleryTaskModel
Expand All @@ -20,27 +19,6 @@ class MainJob(DedupJob):
celery_task_name = "hope_dedup_engine.apps.api.deduplication.process.find_duplicates"


class EncodeChunkJob(DedupJob):
encoding_ids = ArrayField(models.UUIDField(), help_text="Encoding IDs to encode")

celery_task_name = "hope_dedup_engine.apps.faces.celery_tasks.encode_chunk"


class DedupeChunkJob(DedupJob):
encoding_ids0 = ArrayField(models.UUIDField(), help_text="First batch of encoding IDs to encode")
encoding_ids1 = ArrayField(models.UUIDField(), help_text="Second batch of encoding IDs to encode")

celery_task_name = "hope_dedup_engine.apps.faces.celery_tasks.dedupe_chunk"


class CallbackFindingsJob(DedupJob):
celery_task_name = "hope_dedup_engine.apps.faces.celery_tasks.callback_findings"


class DeduplicateDatasetJob(DedupJob):
celery_task_name = "hope_dedup_engine.apps.faces.celery_tasks.deduplicate_dataset"


class SyncDnnFilesJob(CeleryTaskModel):
force = models.BooleanField(
default=False, help_text="If True, forces the re-download of files even if they already exist locally"
Expand Down
49 changes: 6 additions & 43 deletions src/hope_dedup_engine/apps/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
Encoding,
MainJob,
)
from hope_dedup_engine.apps.api.models.jobs import (
EncodeChunkJob,
DeduplicateDatasetJob,
DedupeChunkJob,
CallbackFindingsJob,
)


class DeduplicationSetSerializer(serializers.ModelSerializer):
Expand Down Expand Up @@ -46,46 +40,15 @@ class Meta:
)

def get_status(self, deduplication_set: DeduplicationSet) -> str:
# EncodeChunkJob and DeduplicateDatasetJob are created inside the
# DedupJob. DedupeChunkJob and CallbackFindingsJob are created inside
# the DeduplicateDatasetJob. So we always have the next job object
# created before the current job is finished

job_managers = (
MainJob.objects.filter(deduplication_set=deduplication_set),
EncodeChunkJob.objects.filter(deduplication_set=deduplication_set),
DeduplicateDatasetJob.objects.filter(deduplication_set=deduplication_set),
DedupeChunkJob.objects.filter(deduplication_set=deduplication_set),
CallbackFindingsJob.objects.filter(deduplication_set=deduplication_set),
)

first_task = True

for job_manager in job_managers:
job = job_manager.order_by("-id").first()

if job is None:
# we only get here if no job was scheduled or the previous task
# finished without being able to create the next task, which
# means some other failure
return self.NOT_SCHEDULED

if (result := job.async_result) is None:
# job record was created but the task is not yet started
if first_task:
return CeleryTaskModel.PENDING

# we had some tasks finished before
return CeleryTaskModel.STARTED
job = MainJob.objects.filter(deduplication_set=deduplication_set).order_by("-id").first()

first_task = False
if job is None:
return self.NOT_SCHEDULED

# if the current task status is SUCCESS, we need to check the next
# one
if (status := result.status) != CeleryTaskModel.SUCCESS:
return status
if (result := job.async_result) is None:
return CeleryTaskModel.PENDING

return CeleryTaskModel.SUCCESS
return result.status


class CreateDeduplicationSetSerializer(serializers.ModelSerializer):
Expand Down
Loading
Loading