diff --git a/.gitignore b/.gitignore index ec7fe59c3e..a0e98ead08 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,10 @@ venv.bak/ env.d/development/*.local env.d/terraform +# Docker +compose.override.yml +docker/auth/*.local + # npm node_modules diff --git a/Makefile b/Makefile index 39d6a76450..aa6c342f6b 100644 --- a/Makefile +++ b/Makefile @@ -179,6 +179,10 @@ demo: ## flush db then create a demo for load testing purpose @$(MANAGE) create_demo .PHONY: demo +index: ## index all documents to remote search + @$(MANAGE) index +.PHONY: index + # Nota bene: Black should come after isort just in case they don't agree... lint: ## lint back-end python sources lint: \ diff --git a/compose.yml b/compose.yml index 88e178e353..7d59ceabb7 100644 --- a/compose.yml +++ b/compose.yml @@ -72,6 +72,9 @@ services: - env.d/development/postgresql.local ports: - "8071:8000" + networks: + - default + - lasuite-net volumes: - ./src/backend:/app - ./data/static:/data/static @@ -219,3 +222,8 @@ services: kc_postgresql: condition: service_healthy restart: true + +networks: + lasuite-net: + name: lasuite-net + driver: bridge diff --git a/src/backend/core/api/serializers.py b/src/backend/core/api/serializers.py index 83afc260d9..eb5dab2e3e 100644 --- a/src/backend/core/api/serializers.py +++ b/src/backend/core/api/serializers.py @@ -801,3 +801,16 @@ class MoveDocumentSerializer(serializers.Serializer): choices=enums.MoveNodePositionChoices.choices, default=enums.MoveNodePositionChoices.LAST_CHILD, ) + + +class FindDocumentSerializer(serializers.Serializer): + """Serializer for Find search requests""" + q = serializers.CharField(required=True) + + def validate_q(self, value): + """Ensure the text field is not empty.""" + + if len(value.strip()) == 0: + raise serializers.ValidationError("Text field cannot be empty.") + + return value diff --git a/src/backend/core/api/viewsets.py b/src/backend/core/api/viewsets.py index ee0c594eb1..2a08d90fb4 100644 --- a/src/backend/core/api/viewsets.py +++ b/src/backend/core/api/viewsets.py @@ -19,6 +19,7 @@ from django.db.models.functions import Left, Length from django.http import Http404, StreamingHttpResponse from django.urls import reverse +from django.utils.decorators import method_decorator from django.utils.functional import cached_property from django.utils.text import capfirst, slugify from django.utils.translation import gettext_lazy as _ @@ -29,6 +30,7 @@ from csp.constants import NONE from csp.decorators import csp_update from lasuite.malware_detection import malware_detection +from lasuite.oidc_login.decorators import refresh_oidc_access_token from rest_framework import filters, status, viewsets from rest_framework import response as drf_response from rest_framework.permissions import AllowAny @@ -37,6 +39,7 @@ from core import authentication, choices, enums, models from core.services.ai_services import AIService from core.services.collaboration_services import CollaborationService +from core.services.search_indexers import FindDocumentIndexer from core.tasks.mail import send_ask_for_access_mail from core.utils import extract_attachments, filter_descendants @@ -48,6 +51,12 @@ # pylint: disable=too-many-ancestors +class ServiceUnavailable(drf.exceptions.APIException): + status_code = 503 + default_detail = 'Service unavailable.' + default_code = 'service_unavailable' + + class NestedGenericViewSet(viewsets.GenericViewSet): """ A generic Viewset aims to be used in a nested route context. 
@@ -367,6 +376,7 @@ class DocumentViewSet(
     list_serializer_class = serializers.ListDocumentSerializer
     trashbin_serializer_class = serializers.ListDocumentSerializer
     tree_serializer_class = serializers.ListDocumentSerializer
+    search_serializer_class = serializers.ListDocumentSerializer
 
     def get_queryset(self):
         """Get queryset performing all annotation and filtering on the document tree structure."""
@@ -980,6 +990,33 @@ def duplicate(self, request, *args, **kwargs):
             {"id": str(duplicated_document.id)}, status=status.HTTP_201_CREATED
         )
 
+    @drf.decorators.action(detail=False, methods=["get"], url_path="search")
+    @method_decorator(refresh_oidc_access_token)
+    def search(self, request, *args, **kwargs):
+        """Search documents for the current user through the remote Find service."""
+        access_token = request.session.get("oidc_access_token")
+
+        serializer = serializers.FindDocumentSerializer(
+            data=request.query_params
+        )
+        serializer.is_valid(raise_exception=True)
+
+        indexer = FindDocumentIndexer()
+        try:
+            queryset = indexer.search(
+                text=serializer.validated_data.get("q", ""),
+                user=request.user,
+                token=access_token,
+            )
+        except RuntimeError as err:
+            raise ServiceUnavailable() from err
+
+        return self.get_response_for_queryset(
+            queryset,
+            context={
+                "request": request,
+            },
+        )
+
     @drf.decorators.action(detail=True, methods=["get"], url_path="versions")
     def versions_list(self, request, *args, **kwargs):
         """
diff --git a/src/backend/core/management/commands/index.py b/src/backend/core/management/commands/index.py
new file mode 100644
index 0000000000..b7eab9505e
--- /dev/null
+++ b/src/backend/core/management/commands/index.py
@@ -0,0 +1,28 @@
+"""
+Handle search setup that needs to be done at bootstrap time.
+"""
+
+import logging
+import time
+
+from django.core.management.base import BaseCommand
+
+from ...services.search_indexers import FindDocumentIndexer
+
+logger = logging.getLogger("docs.search.bootstrap_search")
+
+
+class Command(BaseCommand):
+    """Index all documents to the remote search service."""
+
+    help = __doc__
+
+    def handle(self, *args, **options):
+        """Launch and log search index generation."""
+        logger.info("Starting to regenerate Find index...")
+        start = time.perf_counter()
+
+        FindDocumentIndexer().index()
+
+        duration = time.perf_counter() - start
+        logger.info("Search index regenerated in %.2f seconds.", duration)
diff --git a/src/backend/core/models.py b/src/backend/core/models.py
index a1182964da..2e4b4c6e29 100644
--- a/src/backend/core/models.py
+++ b/src/backend/core/models.py
@@ -20,7 +20,9 @@
 from django.core.files.storage import default_storage
 from django.core.mail import send_mail
 from django.db import models, transaction
+from django.db.models import signals
 from django.db.models.functions import Left, Length
+from django.dispatch import receiver
 from django.template.loader import render_to_string
 from django.utils import timezone
 from django.utils.functional import cached_property
@@ -39,6 +41,7 @@
     RoleChoices,
     get_equivalent_link_definition,
 )
+from .tasks.find import trigger_document_indexer
 
 logger = getLogger(__name__)
 
@@ -439,32 +442,35 @@ def __init__(self, *args, **kwargs):
     def save(self, *args, **kwargs):
         """Write content to object storage only if _content has changed."""
         super().save(*args, **kwargs)
         if self._content:
-            file_key = self.file_key
-            bytes_content = self._content.encode("utf-8")
+            self.save_content(self._content)
 
-            # Attempt to directly check if the object exists using the storage client.
- try: - response = default_storage.connection.meta.client.head_object( - Bucket=default_storage.bucket_name, Key=file_key - ) - except ClientError as excpt: - # If the error is a 404, the object doesn't exist, so we should create it. - if excpt.response["Error"]["Code"] == "404": - has_changed = True - else: - raise + def save_content(self, content): + """Save content to object storage.""" + + file_key = self.file_key + bytes_content = content.encode("utf-8") + + # Attempt to directly check if the object exists using the storage client. + try: + response = default_storage.connection.meta.client.head_object( + Bucket=default_storage.bucket_name, Key=file_key + ) + except ClientError as excpt: + # If the error is a 404, the object doesn't exist, so we should create it. + if excpt.response["Error"]["Code"] == "404": + has_changed = True else: - # Compare the existing ETag with the MD5 hash of the new content. - has_changed = ( - response["ETag"].strip('"') - != hashlib.md5(bytes_content).hexdigest() # noqa: S324 - ) + raise + else: + # Compare the existing ETag with the MD5 hash of the new content. + has_changed = ( + response["ETag"].strip('"') != hashlib.md5(bytes_content).hexdigest() # noqa: S324 + ) - if has_changed: - content_file = ContentFile(bytes_content) - default_storage.save(file_key, content_file) + if has_changed: + content_file = ContentFile(bytes_content) + default_storage.save(file_key, content_file) def is_leaf(self): """ @@ -946,6 +952,16 @@ def restore(self): ) +@receiver(signals.post_save, sender=Document) +def document_post_save(sender, instance, **kwargs): + """ + Asynchronous call to the document indexer at the end of the transaction. + Note : Within the transaction we can have an empty content and a serialization + error. + """ + trigger_document_indexer(instance, on_commit=True) + + class LinkTrace(BaseModel): """ Relation model to trace accesses to a document via a link by a logged-in user. @@ -1171,6 +1187,15 @@ def get_abilities(self, user): } +@receiver(signals.post_save, sender=DocumentAccess) +def document_access_post_save(sender, instance, created, **kwargs): + """ + Asynchronous call to the document indexer at the end of the transaction. + """ + if not created: + trigger_document_indexer(instance.document, on_commit=True) + + class DocumentAskForAccess(BaseModel): """Relation model to ask for access to a document.""" diff --git a/src/backend/core/services/search_indexers.py b/src/backend/core/services/search_indexers.py new file mode 100644 index 0000000000..075c7a9fbf --- /dev/null +++ b/src/backend/core/services/search_indexers.py @@ -0,0 +1,263 @@ +"""Document search index management utilities and indexers""" + +import logging +from abc import ABC, abstractmethod +from collections import defaultdict + +from django.conf import settings +from django.contrib.auth.models import AnonymousUser + +import requests + +from core import models, utils + +logger = logging.getLogger(__name__) + + +def get_batch_accesses_by_users_and_teams(paths): + """ + Get accesses related to a list of document paths, + grouped by users and teams, including all ancestor paths. 
+    """
+    ancestor_map = utils.get_ancestor_to_descendants_map(
+        paths, steplen=models.Document.steplen
+    )
+    ancestor_paths = list(ancestor_map.keys())
+
+    access_qs = models.DocumentAccess.objects.filter(
+        document__path__in=ancestor_paths
+    ).values("document__path", "user__sub", "team")
+
+    access_by_document_path = defaultdict(lambda: {"users": set(), "teams": set()})
+
+    for access in access_qs:
+        ancestor_path = access["document__path"]
+        user_sub = access["user__sub"]
+        team = access["team"]
+
+        for descendant_path in ancestor_map.get(ancestor_path, []):
+            if user_sub:
+                access_by_document_path[descendant_path]["users"].add(str(user_sub))
+            if team:
+                access_by_document_path[descendant_path]["teams"].add(team)
+
+    return dict(access_by_document_path)
+
+
+def get_visited_document_ids_of(user):
+    """Return ids of documents the user visited through a link but has no specific access to."""
+    if isinstance(user, AnonymousUser):
+        return []
+
+    # Exclude links when the user already has a specific access to the document
+    qs = models.LinkTrace.objects.filter(
+        user=user
+    ).exclude(
+        document__accesses__user=user,
+    )
+
+    return list({
+        str(doc_id) for doc_id in qs.values_list("document_id", flat=True)
+    })
+
+
+class BaseDocumentIndexer(ABC):
+    """
+    Base class for document indexers.
+
+    Handles batching and access resolution. Subclasses must implement both
+    `serialize_document()` and `push()` to define backend-specific behavior.
+    """
+
+    def __init__(self, batch_size=None):
+        """
+        Initialize the indexer.
+
+        Args:
+            batch_size (int, optional): Number of documents per batch.
+                Defaults to settings.SEARCH_INDEXER_BATCH_SIZE.
+        """
+        self.batch_size = batch_size or settings.SEARCH_INDEXER_BATCH_SIZE
+
+    def index(self):
+        """
+        Fetch documents in batches, serialize them, and push to the search backend.
+        """
+        last_id = 0
+        while True:
+            documents_batch = list(
+                models.Document.objects.filter(
+                    id__gt=last_id,
+                ).order_by("id")[: self.batch_size]
+            )
+
+            if not documents_batch:
+                break
+
+            doc_paths = [doc.path for doc in documents_batch]
+            last_id = documents_batch[-1].id
+            accesses_by_document_path = get_batch_accesses_by_users_and_teams(doc_paths)
+
+            serialized_batch = [
+                self.serialize_document(document, accesses_by_document_path)
+                for document in documents_batch
+                if document.content
+            ]
+            self.push(serialized_batch)
+
+    @abstractmethod
+    def serialize_document(self, document, accesses):
+        """
+        Convert a Document instance to a JSON-serializable format for indexing.
+
+        Must be implemented by subclasses.
+        """
+
+    @abstractmethod
+    def push(self, data):
+        """
+        Push a batch of serialized documents to the backend.
+
+        Must be implemented by subclasses.
+        """
+
+    def search(self, text, user, token):
+        """
+        Search for documents in the Find app.
+        """
+        visited_ids = get_visited_document_ids_of(user)
+
+        response = self.search_query(data={
+            "q": text,
+            "visited": visited_ids,
+            "services": ["docs"],
+        }, token=token)
+
+        return self.format_response(response)
+
+    @abstractmethod
+    def search_query(self, data, token) -> dict:
+        """
+        Retrieve documents from the Find app API.
+
+        Must be implemented by subclasses.
+        """
+
+    @abstractmethod
+    def format_response(self, data: dict):
+        """
+        Convert the JSON response from the Find app into a document queryset.
+
+        Must be implemented by subclasses.
+        """
+
+
+class FindDocumentIndexer(BaseDocumentIndexer):
+    """
+    Document indexer that pushes documents to La Suite Find app.
+    """
+
+    def serialize_document(self, document, accesses):
+        """
+        Convert a Document to the JSON format expected by La Suite Find.
+
+        Args:
+            document (Document): The document instance.
+            accesses (dict): Mapping of document path to user/team accesses.
+
+        Returns:
+            dict: A JSON-serializable dictionary.
+        """
+        doc_path = document.path
+        doc_content = document.content
+        text_content = utils.base64_yjs_to_text(doc_content) if doc_content else ""
+
+        return {
+            "id": str(document.id),
+            "title": document.title or "",
+            "content": text_content,
+            "depth": document.depth,
+            "path": document.path,
+            "numchild": document.numchild,
+            "created_at": document.created_at.isoformat(),
+            "updated_at": document.updated_at.isoformat(),
+            "users": list(accesses.get(doc_path, {}).get("users", set())),
+            "groups": list(accesses.get(doc_path, {}).get("teams", set())),
+            "reach": document.computed_link_reach,
+            "size": len(text_content.encode("utf-8")),
+            "is_active": not bool(document.ancestors_deleted_at),
+        }
+
+    def search_query(self, data, token) -> dict:
+        """
+        Retrieve documents from the Find app API.
+
+        Args:
+            data (dict): search data
+            token (str): OIDC token
+
+        Returns:
+            dict: The parsed JSON response from the Find app.
+        """
+        url = getattr(settings, "SEARCH_INDEXER_QUERY_URL", None)
+
+        if not url:
+            raise RuntimeError(
+                "SEARCH_INDEXER_QUERY_URL must be set in Django settings before search."
+            )
+
+        try:
+            response = requests.post(
+                url,
+                json=data,
+                headers={"Authorization": f"Bearer {token}"},
+                timeout=10,
+            )
+            response.raise_for_status()
+            return response.json()
+        except requests.exceptions.HTTPError as e:
+            logger.error("HTTPError: %s", e)
+            logger.error("Response content: %s", response.text)  # type: ignore
+            raise
+
+    def format_response(self, data: dict):
+        """
+        Retrieve document ids from the Find app response and return a queryset.
+        """
+        return models.Document.objects.filter(pk__in=[
+            d["_id"] for d in data
+        ])
+
+    def push(self, data):
+        """
+        Push a batch of documents to the Find backend.
+
+        Args:
+            data (list): List of document dictionaries.
+        """
+        url = getattr(settings, "SEARCH_INDEXER_URL", None)
+        if not url:
+            raise RuntimeError(
+                "SEARCH_INDEXER_URL must be set in Django settings before indexing."
+            )
+
+        secret = getattr(settings, "SEARCH_INDEXER_SECRET", None)
+        if not secret:
+            raise RuntimeError(
+                "SEARCH_INDEXER_SECRET must be set in Django settings before indexing."
+            )
+
+        try:
+            response = requests.post(
+                url,
+                json=data,
+                headers={"Authorization": f"Bearer {secret}"},
+                timeout=10,
+            )
+            response.raise_for_status()
+        except requests.exceptions.HTTPError as e:
+            logger.error("HTTPError: %s", e)
+            logger.error("Response content: %s", response.text)
+            raise
diff --git a/src/backend/core/tasks/find.py b/src/backend/core/tasks/find.py
new file mode 100644
index 0000000000..858c83ee01
--- /dev/null
+++ b/src/backend/core/tasks/find.py
@@ -0,0 +1,96 @@
+"""Trigger document indexation using a Celery task."""
+
+from logging import getLogger
+
+from django.conf import settings
+from django.core.cache import cache
+from django.db import transaction
+
+from core import models
+from core.services.search_indexers import (
+    FindDocumentIndexer,
+    get_batch_accesses_by_users_and_teams,
+)
+
+from impress.celery_app import app
+
+logger = getLogger(__name__)
+
+
+def document_indexer_debounce_key(document_id):
+    """Return the debounce cache key for a document."""
+    return f"doc-indexer-debounce-{document_id}"
+
+
+def incr_counter(key):
+    """Increment the counter, initializing it to 1 if it does not exist."""
+    try:
+        return cache.incr(key)
+    except ValueError:
+        cache.set(key, 1)
+        return 1
+
+
+def decr_counter(key):
+    """Decrement the counter, initializing it to 0 if it does not exist."""
+    try:
+        return cache.decr(key)
+    except ValueError:
+        cache.set(key, 0)
+        return 0
+
+
+@app.task
+def document_indexer_task(document_id):
+    """Send the indexation query for a document using a Celery task."""
+    key = document_indexer_debounce_key(document_id)
+
+    # Check the counter: if it is still positive, skip this task. Only the last
+    # task queued within the countdown delay performs the query.
+    if decr_counter(key) > 0:
+        logger.info("Skip document %s indexation", document_id)
+        return
+
+    doc = models.Document.objects.get(pk=document_id)
+    indexer = FindDocumentIndexer()
+    accesses = get_batch_accesses_by_users_and_teams((doc.path,))
+
+    data = indexer.serialize_document(document=doc, accesses=accesses)
+
+    logger.info("Start document %s indexation", document_id)
+    indexer.push(data)
+
+
+def trigger_document_indexer(document, on_commit=False):
+    """
+    Trigger the indexation task with a debounce delay set by the SEARCH_INDEXER_COUNTDOWN setting.
+
+    Args:
+        document (Document): The document instance.
+        on_commit (bool): Wait for the end of the transaction before starting the task
+            (some fields may be in an inconsistent state within the transaction)
+    """
+
+    # Deleted documents are not re-indexed
+    if document.deleted_at or document.ancestors_deleted_at:
+        return
+
+    if on_commit:
+
+        def _aux():
+            trigger_document_indexer(document, on_commit=False)
+
+        transaction.on_commit(_aux)
+    else:
+        key = document_indexer_debounce_key(document.pk)
+        countdown = getattr(settings, "SEARCH_INDEXER_COUNTDOWN", 1)
+
+        logger.info(
+            "Add task for document %s indexation in %.2f seconds",
+            document.pk, countdown
+        )
+
+        # Each time this method is called during the countdown, we increment the
+        # counter and each task decrements it, so the indexing runs only once.
+        incr_counter(key)
+
+        document_indexer_task.apply_async(args=[document.pk], countdown=countdown)
diff --git a/src/backend/core/tests/commands/test_index.py b/src/backend/core/tests/commands/test_index.py
new file mode 100644
index 0000000000..1da70344e1
--- /dev/null
+++ b/src/backend/core/tests/commands/test_index.py
@@ -0,0 +1,49 @@
+"""
+Unit test for the `index` command.
+"""
+
+from unittest import mock
+
+from django.core.management import call_command
+from django.db import transaction
+
+import pytest
+
+from core import factories
+from core.services.search_indexers import FindDocumentIndexer
+
+
+@pytest.mark.django_db
+def test_index():
+    """Test the command `index` that runs the Find app indexer for all the available documents."""
+    user = factories.UserFactory()
+    indexer = FindDocumentIndexer()
+
+    with transaction.atomic():
+        doc = factories.DocumentFactory()
+        empty_doc = factories.DocumentFactory(title=None, content="")
+        no_title_doc = factories.DocumentFactory(title=None)
+
+        factories.UserDocumentAccessFactory(document=doc, user=user)
+        factories.UserDocumentAccessFactory(document=empty_doc, user=user)
+        factories.UserDocumentAccessFactory(document=no_title_doc, user=user)
+
+    accesses = {
+        str(doc.path): {"users": [user.sub]},
+        str(empty_doc.path): {"users": [user.sub]},
+        str(no_title_doc.path): {"users": [user.sub]},
+    }
+
+    def sortkey(d):
+        return d["id"]
+
+    with mock.patch.object(FindDocumentIndexer, "push") as mock_push:
+        call_command("index")
+
+    push_call_args = [call.args[0] for call in mock_push.call_args_list]
+
+    assert len(push_call_args) == 1  # called once but with a batch of docs
+    assert sorted(push_call_args[0], key=sortkey) == sorted([
+        indexer.serialize_document(doc, accesses),
+        indexer.serialize_document(no_title_doc, accesses),
+    ], key=sortkey)
diff --git a/src/backend/core/tests/documents/test_api_documents_search.py b/src/backend/core/tests/documents/test_api_documents_search.py
new file mode 100644
index 0000000000..cf0bf5690d
--- /dev/null
+++ b/src/backend/core/tests/documents/test_api_documents_search.py
@@ -0,0 +1,129 @@
+"""
+Tests for Documents API endpoint in impress's core app: search
+"""
+import pytest
+import responses
+from faker import Faker
+from rest_framework.test import APIClient
+
+from core import factories, models
+
+fake = Faker()
+pytestmark = pytest.mark.django_db
+
+
+@pytest.mark.parametrize("role", models.LinkRoleChoices.values)
+@pytest.mark.parametrize("reach", models.LinkReachChoices.values)
+def test_api_documents_search_anonymous(reach, role):
+    """
+    Anonymous users should not be allowed to search documents, whatever the
+    link reach and link role.
+    """
+    factories.DocumentFactory(link_reach=reach, link_role=role)
+
+    response = APIClient().get("/api/v1.0/documents/search/", data={"q": "alpha"})
+
+    assert response.status_code == 200
+    assert response.json() == {
+        "count": 0,
+        "next": None,
+        "previous": None,
+        "results": [],
+    }
+
+
+def test_api_documents_search_endpoint_is_none(settings):
+    """A missing SEARCH_INDEXER_QUERY_URL setting should return a 503 error."""
+    settings.SEARCH_INDEXER_QUERY_URL = None
+
+    user = factories.UserFactory()
+
+    client = APIClient()
+    client.force_login(user)
+
+    response = client.get("/api/v1.0/documents/search/", data={"q": "alpha"})
+
+    assert response.status_code == 503
+    assert response.json() == {
+        "detail": "Service unavailable."
+    }
+
+
+@responses.activate
+def test_api_documents_search_invalid_params(settings):
+    """Searching without the required `q` query parameter should return a 400 error."""
+    settings.SEARCH_INDEXER_QUERY_URL = "http://find/api/v1.0/search"
+
+    user = factories.UserFactory()
+
+    client = APIClient()
+    client.force_login(user)
+
+    response = client.get("/api/v1.0/documents/search/")
+
+    assert response.status_code == 400
+    assert response.json() == {
+        "q": ["This field is required."]
+    }
+
+
+@responses.activate
+def test_api_documents_search_format(settings):
+    """Validate the format of documents as returned by the search view."""
+    settings.SEARCH_INDEXER_QUERY_URL = "http://find/api/v1.0/search"
+
+    user = factories.UserFactory()
+
+    client = APIClient()
+    client.force_login(user)
+
+    user_a, user_b, user_c = factories.UserFactory.create_batch(3)
+    document = factories.DocumentFactory(
+        title="alpha",
+        users=(user_a, user_c),
+        link_traces=(user, user_b),
+    )
+    access = factories.UserDocumentAccessFactory(document=document, user=user)
+
+    # Find response
+    responses.add(
+        responses.POST,
+        "http://find/api/v1.0/search",
+        json=[
+            {"_id": str(document.pk)},
+        ],
+        status=200,
+    )
+    response = client.get("/api/v1.0/documents/search/", data={"q": "alpha"})
+
+    assert response.status_code == 200
+    content = response.json()
+    results = content.pop("results")
+    assert content == {
+        "count": 1,
+        "next": None,
+        "previous": None,
+    }
+    assert len(results) == 1
+    assert results[0] == {
+        "id": str(document.id),
+        "abilities": document.get_abilities(user),
+        "ancestors_link_reach": None,
+        "ancestors_link_role": None,
+        "computed_link_reach": document.computed_link_reach,
+        "computed_link_role": document.computed_link_role,
+        "created_at": document.created_at.isoformat().replace("+00:00", "Z"),
+        "creator": str(document.creator.id),
+        "depth": 1,
+        "excerpt": document.excerpt,
+        "link_reach": document.link_reach,
+        "link_role": document.link_role,
+        "nb_accesses_ancestors": 3,
+        "nb_accesses_direct": 3,
+        "numchild": 0,
+        "path": document.path,
+        "title": document.title,
+        "updated_at": document.updated_at.isoformat().replace("+00:00", "Z"),
+        "user_role": access.role,
+    }
diff --git a/src/backend/core/tests/test_models_documents.py b/src/backend/core/tests/test_models_documents.py
index 6874009c97..08e689f4a7 100644
--- a/src/backend/core/tests/test_models_documents.py
+++ b/src/backend/core/tests/test_models_documents.py
@@ -5,6 +5,7 @@
 import random
 import smtplib
+import time
 from logging import Logger
 from unittest import mock
 
@@ -13,12 +14,15 @@
 from django.core.cache import cache
 from django.core.exceptions import ValidationError
 from django.core.files.storage import default_storage
+from django.db import transaction
 from django.test.utils import override_settings
 from django.utils import timezone
 
 import pytest
 
 from core import factories, models
+from core.services.search_indexers import FindDocumentIndexer
+from core.tasks.find import document_indexer_debounce_key
 
 pytestmark = pytest.mark.django_db
 
@@ -1323,3 +1327,125 @@ def test_models_documents_compute_ancestors_links_paths_mapping_structure(
             {"link_reach": sibling.link_reach, "link_role": sibling.link_role},
         ],
     }
+
+
+@mock.patch.object(FindDocumentIndexer, "push")
+@pytest.mark.django_db(transaction=True)
+def test_models_documents_post_save_indexer(mock_push, settings):
+    """Test indexation task on document creation."""
+    settings.SEARCH_INDEXER_COUNTDOWN = 0
+
+    user = factories.UserFactory()
+
+    with transaction.atomic():
+        doc1,
doc2, doc3 = factories.DocumentFactory.create_batch(3) + + factories.UserDocumentAccessFactory(document=doc1, user=user) + factories.UserDocumentAccessFactory(document=doc2, user=user) + factories.UserDocumentAccessFactory(document=doc3, user=user) + + time.sleep(0.1) # waits for the end of the tasks + + accesses = { + str(doc1.path): {"users": [user.sub]}, + str(doc2.path): {"users": [user.sub]}, + str(doc3.path): {"users": [user.sub]}, + } + + data = [call.args[0] for call in mock_push.call_args_list] + + indexer = FindDocumentIndexer() + + def sortkey(d): + return d["id"] + + assert sorted(data, key=sortkey) == sorted( + [ + indexer.serialize_document(doc1, accesses), + indexer.serialize_document(doc2, accesses), + indexer.serialize_document(doc3, accesses), + ], + key=sortkey, + ) + + # The debounce counters should be reset + assert cache.get(document_indexer_debounce_key(doc1.pk)) == 0 + assert cache.get(document_indexer_debounce_key(doc2.pk)) == 0 + assert cache.get(document_indexer_debounce_key(doc3.pk)) == 0 + + +@pytest.mark.django_db(transaction=True) +def test_models_documents_post_save_indexer_debounce(settings): + """Test indexation task skipping on document update""" + settings.SEARCH_INDEXER_COUNTDOWN = 0 + + indexer = FindDocumentIndexer() + user = factories.UserFactory() + + with mock.patch.object(FindDocumentIndexer, "push"): + with transaction.atomic(): + doc = factories.DocumentFactory() + factories.UserDocumentAccessFactory(document=doc, user=user) + + accesses = { + str(doc.path): {"users": [user.sub]}, + } + + time.sleep(0.1) # waits for the end of the tasks + + with mock.patch.object(FindDocumentIndexer, "push") as mock_push: + # Simulate 1 waiting task + cache.set(document_indexer_debounce_key(doc.pk), 1) + + # save doc to trigger the indexer, but nothing should be done since + # the counter is over 0 + with transaction.atomic(): + doc.save() + + time.sleep(0.1) + + assert [call.args[0] for call in mock_push.call_args_list] == [] + + with mock.patch.object(FindDocumentIndexer, "push") as mock_push: + # No waiting task + cache.set(document_indexer_debounce_key(doc.pk), 0) + + with transaction.atomic(): + doc = models.Document.objects.get(pk=doc.pk) + doc.save() + + time.sleep(0.1) + + assert [call.args[0] for call in mock_push.call_args_list] == [ + indexer.serialize_document(doc, accesses), + ] + + +@pytest.mark.django_db(transaction=True) +def test_models_documents_access_post_save_indexer(settings): + """Test indexation task on DocumentAccess update""" + settings.SEARCH_INDEXER_COUNTDOWN = 0 + + indexer = FindDocumentIndexer() + user = factories.UserFactory() + + with mock.patch.object(FindDocumentIndexer, "push"): + with transaction.atomic(): + doc = factories.DocumentFactory() + doc_access = factories.UserDocumentAccessFactory(document=doc, user=user) + + accesses = { + str(doc.path): {"users": [user.sub]}, + } + + indexer = FindDocumentIndexer() + + with mock.patch.object(FindDocumentIndexer, "push") as mock_push: + with transaction.atomic(): + doc_access.save() + + time.sleep(0.1) + + assert [call.args[0] for call in mock_push.call_args_list] == [ + indexer.serialize_document(doc, accesses), + ] diff --git a/src/backend/core/tests/test_services_search_indexers.py b/src/backend/core/tests/test_services_search_indexers.py new file mode 100644 index 0000000000..b9bffc914e --- /dev/null +++ b/src/backend/core/tests/test_services_search_indexers.py @@ -0,0 +1,339 @@ +"""Tests for Documents search indexers""" + +from functools import partial +from unittest.mock 
import patch + +import pytest + +from django.contrib.auth.models import AnonymousUser + +from core import factories, models, utils +from core.services.search_indexers import FindDocumentIndexer, get_visited_document_ids_of + +pytestmark = pytest.mark.django_db + + +def test_push_raises_error_if_search_indexer_url_is_none(settings): + """ + Indexer should raise RuntimeError if SEARCH_INDEXER_URL is None or empty. + """ + settings.SEARCH_INDEXER_URL = None + indexer = FindDocumentIndexer() + + with pytest.raises(RuntimeError) as exc_info: + indexer.push([]) + + assert "SEARCH_INDEXER_URL must be set in Django settings before indexing." in str( + exc_info.value + ) + + +def test_push_raises_error_if_search_indexer_url_is_empty(settings): + """ + Indexer should raise RuntimeError if SEARCH_INDEXER_URL is empty string. + """ + settings.SEARCH_INDEXER_URL = "" + indexer = FindDocumentIndexer() + + with pytest.raises(RuntimeError) as exc_info: + indexer.push([]) + + assert "SEARCH_INDEXER_URL must be set in Django settings before indexing." in str( + exc_info.value + ) + + +def test_push_raises_error_if_search_indexer_secret_is_none(settings): + """ + Indexer should raise RuntimeError if SEARCH_INDEXER_SECRET is None or empty. + """ + settings.SEARCH_INDEXER_SECRET = None + indexer = FindDocumentIndexer() + + with pytest.raises(RuntimeError) as exc_info: + indexer.push([]) + + assert ( + "SEARCH_INDEXER_SECRET must be set in Django settings before indexing." + in str(exc_info.value) + ) + + +def test_push_raises_error_if_search_indexer_secret_is_empty(settings): + """ + Indexer should raise RuntimeError if SEARCH_INDEXER_SECRET is empty string. + """ + settings.SEARCH_INDEXER_SECRET = "" + indexer = FindDocumentIndexer() + + with pytest.raises(RuntimeError) as exc_info: + indexer.push([]) + + assert ( + "SEARCH_INDEXER_SECRET must be set in Django settings before indexing." + in str(exc_info.value) + ) + + +def test_services_search_indexers_serialize_document_returns_expected_json(): + """ + It should serialize documents with correct metadata and access control. 
+    """
+    user_a, user_b = factories.UserFactory.create_batch(2)
+    document = factories.DocumentFactory()
+    factories.DocumentFactory(parent=document)
+
+    factories.UserDocumentAccessFactory(document=document, user=user_a)
+    factories.UserDocumentAccessFactory(document=document, user=user_b)
+    factories.TeamDocumentAccessFactory(document=document, team="team1")
+    factories.TeamDocumentAccessFactory(document=document, team="team2")
+
+    accesses = {
+        document.path: {
+            "users": {str(user_a.sub), str(user_b.sub)},
+            "teams": {"team1", "team2"},
+        }
+    }
+
+    indexer = FindDocumentIndexer()
+    result = indexer.serialize_document(document, accesses)
+
+    assert set(result.pop("users")) == {str(user_a.sub), str(user_b.sub)}
+    assert set(result.pop("groups")) == {"team1", "team2"}
+    assert result == {
+        "id": str(document.id),
+        "title": document.title,
+        "depth": 1,
+        "path": document.path,
+        "numchild": 1,
+        "content": utils.base64_yjs_to_text(document.content),
+        "created_at": document.created_at.isoformat(),
+        "updated_at": document.updated_at.isoformat(),
+        "reach": document.link_reach,
+        "size": 13,
+        "is_active": True,
+    }
+
+
+def test_services_search_indexers_serialize_document_deleted():
+    """Deleted documents are just marked as inactive in the serialized json."""
+    parent = factories.DocumentFactory()
+    document = factories.DocumentFactory(parent=parent)
+
+    parent.soft_delete()
+    document.refresh_from_db()
+
+    indexer = FindDocumentIndexer()
+    result = indexer.serialize_document(document, {})
+
+    assert result["is_active"] is False
+
+
+def test_services_search_indexers_serialize_document_empty():
+    """Empty documents return empty content in the serialized json."""
+    document = factories.DocumentFactory(content="", title=None)
+
+    indexer = FindDocumentIndexer()
+    result = indexer.serialize_document(document, {})
+
+    assert result["content"] == ""
+    assert result["title"] == ""
+
+
+@patch.object(FindDocumentIndexer, "push")
+def test_services_search_indexers_batches_pass_only_batch_accesses(mock_push, settings):
+    """
+    Document indexing should be processed in batches,
+    and only the access data relevant to each batch should be used.
+ """ + settings.SEARCH_INDEXER_BATCH_SIZE = 2 + documents = factories.DocumentFactory.create_batch(5) + + # Attach a single user access to each document + expected_user_subs = {} + for document in documents: + access = factories.UserDocumentAccessFactory(document=document) + expected_user_subs[str(document.id)] = str(access.user.sub) + + FindDocumentIndexer().index() + + # Should be 3 batches: 2 + 2 + 1 + assert mock_push.call_count == 3 + + seen_doc_ids = set() + + for call in mock_push.call_args_list: + batch = call.args[0] + assert isinstance(batch, list) + + for doc_json in batch: + doc_id = doc_json["id"] + seen_doc_ids.add(doc_id) + + # Only one user expected per document + assert doc_json["users"] == [expected_user_subs[doc_id]] + assert doc_json["groups"] == [] + + # Make sure all 5 documents were indexed + assert seen_doc_ids == {str(d.id) for d in documents} + + +@patch.object(FindDocumentIndexer, "push") +def test_services_search_indexers_ancestors_link_reach(mock_push): + """Document accesses and reach should take into account ancestors link reaches.""" + great_grand_parent = factories.DocumentFactory(link_reach="restricted") + grand_parent = factories.DocumentFactory( + parent=great_grand_parent, link_reach="authenticated" + ) + parent = factories.DocumentFactory(parent=grand_parent, link_reach="public") + document = factories.DocumentFactory(parent=parent, link_reach="restricted") + + FindDocumentIndexer().index() + + seen_doc_ids = set() + results = {doc["id"]: doc for doc in mock_push.call_args[0][0]} + assert len(results) == 4 + assert results[str(great_grand_parent.id)]["reach"] == "restricted" + assert results[str(grand_parent.id)]["reach"] == "authenticated" + assert results[str(parent.id)]["reach"] == "public" + assert results[str(document.id)]["reach"] == "public" + + +@patch.object(FindDocumentIndexer, "push") +def test_services_search_indexers_ancestors_users(mock_push): + """Document accesses and reach should include users from ancestors.""" + user_gp, user_p, user_d = factories.UserFactory.create_batch(3) + + grand_parent = factories.DocumentFactory(users=[user_gp]) + parent = factories.DocumentFactory(parent=grand_parent, users=[user_p]) + document = factories.DocumentFactory(parent=parent, users=[user_d]) + + FindDocumentIndexer().index() + + seen_doc_ids = set() + results = {doc["id"]: doc for doc in mock_push.call_args[0][0]} + assert len(results) == 3 + assert results[str(grand_parent.id)]["users"] == [str(user_gp.sub)] + assert set(results[str(parent.id)]["users"]) == {str(user_gp.sub), str(user_p.sub)} + assert set(results[str(document.id)]["users"]) == { + str(user_gp.sub), + str(user_p.sub), + str(user_d.sub), + } + + +@patch.object(FindDocumentIndexer, "push") +def test_services_search_indexers_ancestors_teams(mock_push): + """Document accesses and reach should include teams from ancestors.""" + grand_parent = factories.DocumentFactory(teams=["team_gp"]) + parent = factories.DocumentFactory(parent=grand_parent, teams=["team_p"]) + document = factories.DocumentFactory(parent=parent, teams=["team_d"]) + + FindDocumentIndexer().index() + + seen_doc_ids = set() + results = {doc["id"]: doc for doc in mock_push.call_args[0][0]} + assert len(results) == 3 + assert results[str(grand_parent.id)]["groups"] == ["team_gp"] + assert set(results[str(parent.id)]["groups"]) == {"team_gp", "team_p"} + assert set(results[str(document.id)]["groups"]) == {"team_gp", "team_p", "team_d"} + + +@patch("requests.post") +def test_push_uses_correct_url_and_data(mock_post, 
settings):
+    """
+    push() should call requests.post with the correct URL from settings,
+    a timeout of 10 seconds and the data as JSON.
+    """
+    settings.SEARCH_INDEXER_URL = "http://example.com/index"
+
+    indexer = FindDocumentIndexer()
+    sample_data = [{"id": "123", "title": "Test"}]
+
+    mock_response = mock_post.return_value
+    mock_response.raise_for_status.return_value = None  # No error
+
+    indexer.push(sample_data)
+
+    mock_post.assert_called_once()
+    args, kwargs = mock_post.call_args
+
+    assert args[0] == settings.SEARCH_INDEXER_URL
+    assert kwargs.get("json") == sample_data
+    assert kwargs.get("timeout") == 10
+
+
+def test_get_visited_document_ids_of():
+    """
+    get_visited_document_ids_of() returns the ids of the documents viewed
+    by the user BUT without specific access configuration (like public ones).
+    """
+    user = factories.UserFactory()
+    other = factories.UserFactory()
+    anonymous = AnonymousUser()
+
+    assert get_visited_document_ids_of(anonymous) == []
+    assert get_visited_document_ids_of(user) == []
+
+    doc1, doc2, _ = factories.DocumentFactory.create_batch(3)
+
+    create_link = partial(models.LinkTrace.objects.create, user=user, is_masked=False)
+
+    create_link(document=doc1)
+    create_link(document=doc2)
+
+    # The third document is not visited
+    assert sorted(get_visited_document_ids_of(user)) == sorted([str(doc1.pk), str(doc2.pk)])
+
+    factories.UserDocumentAccessFactory(user=other, document=doc1)
+    factories.UserDocumentAccessFactory(user=user, document=doc2)
+
+    # The second document now has an access for the user
+    assert get_visited_document_ids_of(user) == [str(doc1.pk)]
+
+
+@patch("requests.post")
+def test_services_search_indexers_search(mock_post, settings):
+    """search() should post the text, the visited document ids and the service to the Find API."""
+    settings.SEARCH_INDEXER_QUERY_URL = "http://find/api/v1.0/search"
+
+    user = factories.UserFactory()
+    indexer = FindDocumentIndexer()
+
+    mock_response = mock_post.return_value
+    mock_response.raise_for_status.return_value = None  # No error
+
+    doc1, doc2, _ = factories.DocumentFactory.create_batch(3)
+
+    create_link = partial(models.LinkTrace.objects.create, user=user, is_masked=False)
+
+    create_link(document=doc1)
+    create_link(document=doc2)
+
+    indexer.search("alpha", user=user, token="mytoken")
+
+    args, kwargs = mock_post.call_args
+
+    assert args[0] == settings.SEARCH_INDEXER_QUERY_URL
+
+    query_data = kwargs.get("json")
+    assert query_data["q"] == "alpha"
+    assert sorted(query_data["visited"]) == sorted([str(doc1.pk), str(doc2.pk)])
+    assert query_data["services"] == ["docs"]
+
+    assert kwargs.get("headers") == {"Authorization": "Bearer mytoken"}
+    assert kwargs.get("timeout") == 10
+
+
+def test_search_query_raises_error_if_search_endpoint_is_none(settings):
+    """
+    Indexer should raise RuntimeError if SEARCH_INDEXER_QUERY_URL is None or empty.
+    """
+    settings.SEARCH_INDEXER_QUERY_URL = None
+    indexer = FindDocumentIndexer()
+    user = factories.UserFactory()
+
+    with pytest.raises(RuntimeError) as exc_info:
+        indexer.search("alpha", user=user, token="mytoken")
+
+    assert (
+        "SEARCH_INDEXER_QUERY_URL must be set in Django settings before search."
+ in str(exc_info.value) + ) diff --git a/src/backend/core/tests/test_utils.py b/src/backend/core/tests/test_utils.py index 37b2e32d5e..42d588c536 100644 --- a/src/backend/core/tests/test_utils.py +++ b/src/backend/core/tests/test_utils.py @@ -75,3 +75,28 @@ def test_utils_extract_attachments(): base64_string = base64.b64encode(update).decode("utf-8") # image_key2 is missing the "/media/" part and shouldn't get extracted assert utils.extract_attachments(base64_string) == [image_key1, image_key3] + + +def test_utils_get_ancestor_to_descendants_map_single_path(): + """Test ancestor mapping of a single path.""" + paths = ["000100020005"] + result = utils.get_ancestor_to_descendants_map(paths, steplen=4) + + assert result == { + "0001": {"000100020005"}, + "00010002": {"000100020005"}, + "000100020005": {"000100020005"}, + } + + +def test_utils_get_ancestor_to_descendants_map_multiple_paths(): + """Test ancestor mapping of multiple paths with shared prefixes.""" + paths = ["000100020005", "00010003"] + result = utils.get_ancestor_to_descendants_map(paths, steplen=4) + + assert result == { + "0001": {"000100020005", "00010003"}, + "00010002": {"000100020005"}, + "000100020005": {"000100020005"}, + "00010003": {"00010003"}, + } diff --git a/src/backend/core/utils.py b/src/backend/core/utils.py index 780431f495..357ede03c3 100644 --- a/src/backend/core/utils.py +++ b/src/backend/core/utils.py @@ -2,6 +2,7 @@ import base64 import re +from collections import defaultdict import pycrdt from bs4 import BeautifulSoup @@ -9,6 +10,27 @@ from core import enums +def get_ancestor_to_descendants_map(paths, steplen): + """ + Given a list of document paths, return a mapping of ancestor_path -> set of descendant_paths. + + Each path is assumed to use materialized path format with fixed-length segments. + + Args: + paths (list of str): List of full document paths. + steplen (int): Length of each path segment. + + Returns: + dict[str, set[str]]: Mapping from ancestor path to its descendant paths (including itself). + """ + ancestor_map = defaultdict(set) + for path in paths: + for i in range(steplen, len(path) + 1, steplen): + ancestor = path[:i] + ancestor_map[ancestor].add(path) + return ancestor_map + + def filter_descendants(paths, root_paths, skip_sorting=False): """ Filters paths to keep only those that are descendants of any path in root_paths. 
diff --git a/src/backend/demo/management/commands/create_demo.py b/src/backend/demo/management/commands/create_demo.py index 74c0270922..b500f0b294 100644 --- a/src/backend/demo/management/commands/create_demo.py +++ b/src/backend/demo/management/commands/create_demo.py @@ -1,16 +1,19 @@ # ruff: noqa: S311, S106 """create_demo management command""" +import base64 import logging import math import random import time from collections import defaultdict +from uuid import uuid4 from django import db from django.conf import settings from django.core.management.base import BaseCommand, CommandError +import pycrdt from faker import Faker from core import models @@ -27,6 +30,16 @@ def random_true_with_probability(probability): return random.random() < probability +def get_ydoc_for_text(text): + """Return a ydoc from plain text for demo purposes.""" + ydoc = pycrdt.Doc() + paragraph = pycrdt.XmlElement("p", {}, [pycrdt.XmlText(text)]) + fragment = pycrdt.XmlFragment([paragraph]) + ydoc["document-store"] = fragment + update = ydoc.get_update() + return base64.b64encode(update).decode("utf-8") + + class BulkQueue: """A utility class to create Django model instances in bulk by just pushing to a queue.""" @@ -48,7 +61,7 @@ def _bulk_create(self, objects): self.queue[objects[0]._meta.model.__name__] = [] # noqa: SLF001 def push(self, obj): - """Add a model instance to queue to that it gets created in bulk.""" + """Add a model instance to queue so that it gets created in bulk.""" objects = self.queue[obj._meta.model.__name__] # noqa: SLF001 objects.append(obj) if len(objects) > self.BATCH_SIZE: @@ -139,17 +152,19 @@ def create_demo(stdout): # pylint: disable=protected-access key = models.Document._int2str(i) # noqa: SLF001 padding = models.Document.alphabet[0] * (models.Document.steplen - len(key)) - queue.push( - models.Document( - depth=1, - path=f"{padding}{key}", - creator_id=random.choice(users_ids), - title=fake.sentence(nb_words=4), - link_reach=models.LinkReachChoices.AUTHENTICATED - if random_true_with_probability(0.5) - else random.choice(models.LinkReachChoices.values), - ) + title = fake.sentence(nb_words=4) + document = models.Document( + id=uuid4(), + depth=1, + path=f"{padding}{key}", + creator_id=random.choice(users_ids), + title=title, + link_reach=models.LinkReachChoices.AUTHENTICATED + if random_true_with_probability(0.5) + else random.choice(models.LinkReachChoices.values), ) + document.save_content(get_ydoc_for_text(f"Content for {title:s}")) + queue.push(document) queue.flush() diff --git a/src/backend/impress/settings.py b/src/backend/impress/settings.py index 730574e3d2..ecb3b005c5 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -99,6 +99,20 @@ class Base(Configuration): } DEFAULT_AUTO_FIELD = "django.db.models.AutoField" + # Search + SEARCH_INDEXER_BATCH_SIZE = values.IntegerValue( + default=100_000, environ_name="SEARCH_INDEXER_BATCH_SIZE", environ_prefix=None + ) + SEARCH_INDEXER_URL = values.Value( + default=None, environ_name="SEARCH_INDEXER_URL", environ_prefix=None + ) + SEARCH_INDEXER_SECRET = values.Value( + default=None, environ_name="SEARCH_INDEXER_SECRET", environ_prefix=None + ) + SEARCH_INDEXER_QUERY_URL = values.Value( + default=None, environ_name="SEARCH_INDEXER_QUERY_URL", environ_prefix=None + ) + # Static files (CSS, JavaScript, Images) STATIC_URL = "/static/" STATIC_ROOT = os.path.join(DATA_DIR, "static")
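
A minimal usage sketch for the new settings and endpoint. The Find URLs and the user below are illustrative placeholders; only the setting names, the `make index` target and the `/api/v1.0/documents/search/` route come from this change.

# Hypothetical development configuration, e.g. in env.d/development/common.local:
#   SEARCH_INDEXER_URL=http://find:8000/documents/index/         (where batches are pushed)
#   SEARCH_INDEXER_SECRET=some-shared-secret                      (bearer token for pushes)
#   SEARCH_INDEXER_QUERY_URL=http://find:8000/api/v1.0/search     (where queries are sent)
#
# A full reindex is then triggered with `make index` (i.e. `python manage.py index`),
# and the new endpoint can be exercised from a test or a Django shell:

from rest_framework.test import APIClient

client = APIClient()
client.force_login(user)  # `user` is any existing authenticated user
response = client.get("/api/v1.0/documents/search/", data={"q": "alpha"})
# 200 with a paginated ListDocumentSerializer payload,
# 400 if `q` is missing or blank, 503 if SEARCH_INDEXER_QUERY_URL is not configured.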
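
The debounce in core/tasks/find.py is a cache counter shared between the save signals and the Celery tasks. A sketch of the intended flow, assuming three rapid saves of the same document within the countdown window (`document` stands for any saved Document instance):

from core.tasks.find import (
    decr_counter,
    document_indexer_debounce_key,
    incr_counter,
)

key = document_indexer_debounce_key(document.pk)

# Each save within the countdown window increments the counter and schedules one task.
incr_counter(key)  # -> 1
incr_counter(key)  # -> 2
incr_counter(key)  # -> 3

# Each task decrements the counter when it fires; a task only indexes when it
# finds the counter at zero, i.e. no newer save is pending behind it.
decr_counter(key)  # -> 2, task skipped
decr_counter(key)  # -> 1, task skipped
decr_counter(key)  # -> 0, this last task serializes the document and pushes it to Find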