Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions features/content_publish.feature
Original file line number Diff line number Diff line change
Expand Up @@ -2741,7 +2741,7 @@ Feature: Content Publishing
"related--1": {
"_id": "text",
"type": "text",
"state": "__no_value__"
"state": "published"
}
}
}]}
Expand All @@ -2759,7 +2759,7 @@ Feature: Content Publishing
"related--1": {
"_id": "text",
"type": "text",
"state": "__no_value__"
"state": "published"
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@
"xmlsec>=1.3.13,<1.3.15",
# Async libraries
"motor>=3.4.0,<4.0",
"pydantic>=2.7.4,<3.0",
# There's a breaking change in 2.13 that's affecting our usage, so pinning to 2.12
"pydantic>=2.7.4,<2.13",
# Custom repos, with patches applied
"eve @ git+https://github.com/superdesk/eve@async",
"eve-elastic @ git+https://github.com/superdesk/eve-elastic@async",
Expand Down
11 changes: 10 additions & 1 deletion superdesk/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from superdesk import json_utils
from superdesk.logging import logger
from superdesk.flask import Flask
from superdesk.core.tasks import run_in_thread


class SuperdeskMangler(hermes.Mangler):
Expand Down Expand Up @@ -63,6 +64,9 @@ def init_app(self, app):
password=parsed_url.password if parsed_url.password else None,
port=int(parsed_url.port) if parsed_url.port else 6379,
db=int(parsed_url.path[1:]) if parsed_url.path else 0,
socket_timeout=app.config.get("CACHE_REDIS_TIMEOUT", 10),
socket_connect_timeout=app.config.get("CACHE_REDIS_CONNECT_TIMEOUT", 2),
retry_on_timeout=True,
)
logger.info("using redis cache backend")
return
Expand Down Expand Up @@ -109,5 +113,10 @@ def clean(self):
return self._backend.clean()


class SuperdeskCache(hermes.Hermes):
    """Hermes cache with support for non-blocking invalidation."""

    def clean_in_thread(self, tags: list[str]) -> None:
        """Invalidate the given cache *tags* from a background thread.

        Returns immediately; the actual ``clean`` call runs via
        ``run_in_thread`` so the caller is never blocked on Redis I/O.
        """
        run_in_thread(self.clean, tags)


# Module-level singletons used across the app: one shared backend and
# one shared cache instance (10 minute default TTL).
cache_backend = SuperdeskCacheBackend(SuperdeskMangler())
cache = SuperdeskCache(cache_backend, mangler=cache_backend.mangler, ttl=600)
59 changes: 59 additions & 0 deletions superdesk/core/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import Any, Callable
import asyncio
import logging

logger = logging.getLogger(__name__)
_thread_tasks: set[asyncio.Task] = set()


def run_in_thread(func: Callable, *args: Any, **kwargs: Any) -> asyncio.Task:
"""Run a callable in a thread pool without awaiting it."""

# ``asyncio.to_thread`` copies the current context, so Quart/Flask functionality should still work
task_name = f"thread_task__{func.__module__}_{func.__qualname__}"
coroutine = asyncio.to_thread(func, *args, **kwargs)
task = asyncio.create_task(coroutine, name=task_name)
_thread_tasks.add(task)
task.add_done_callback(_handle_background_task_result)
return task
Comment thread
MarkLark86 marked this conversation as resolved.
Outdated


async def wait_thread_tasks_to_complete(timeout: float = 10) -> None:
"""Wait for all background tasks to complete

:param timeout: The maximum time to wait for the tasks to complete, in seconds.
"""

if not _thread_tasks:
return

# 1. Signal cancellation to all tasks
for task in _thread_tasks:
Comment thread
MarkLark86 marked this conversation as resolved.
Outdated
task.cancel()

try:
# 2. Wrap the gather in a wait_for to enforce the timeout
await asyncio.wait_for(
asyncio.gather(*_thread_tasks, return_exceptions=True),
timeout,
)
except asyncio.TimeoutError:
# 3. Handle tasks that refused to stop in time
still_running = [t for t in _thread_tasks if not t.done()]
logger.warning(f"Background threads shutdown timed out. {len(still_running)} tasks still active.")
finally:
# 4. Clear the set to release references
_thread_tasks.clear()
Comment thread
MarkLark86 marked this conversation as resolved.
Outdated


def _handle_background_task_result(task: asyncio.Task[Any]) -> None:
task_name = task.get_name()

try:
task.result()
except asyncio.CancelledError:
logger.warning("Background task was cancelled", extra={"task_name": task_name})
except Exception:
logger.exception("Background task failed", extra={"task_name": task_name})
finally:
_thread_tasks.discard(task)
12 changes: 12 additions & 0 deletions superdesk/default_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,18 @@ def local_to_utc_hour(hour):
#: cache type - set explicit cache type if it wouldn't get it right from url
CACHE_TYPE = env("SUPERDESK_CACHE_TYPE")

#: Cache Redis socket timeout, in seconds (redis-py ``socket_timeout``)
#:
#: .. versionadded:: 3.5
#:
CACHE_REDIS_TIMEOUT = float(env("CACHE_REDIS_TIMEOUT", 10))

#: Cache Redis socket connection timeout, in seconds (redis-py ``socket_connect_timeout``)
#:
#: .. versionadded:: 3.5
#:
CACHE_REDIS_CONNECT_TIMEOUT = float(env("CACHE_REDIS_CONNECT_TIMEOUT", 2))

#: celery broker
BROKER_URL = env("CELERY_BROKER_URL", REDIS_URL)
CELERY_BROKER_URL = BROKER_URL
Expand Down
25 changes: 15 additions & 10 deletions superdesk/eve_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from eve.methods.common import resolve_document_etag
from elasticsearch.exceptions import RequestError, NotFoundError

import superdesk
from superdesk.core import json, get_app_config, get_current_app
from superdesk.resource_fields import ID_FIELD, ETAG, LAST_UPDATED, DATE_CREATED
from superdesk.errors import SuperdeskApiError
Expand Down Expand Up @@ -273,7 +274,7 @@ def find_and_modify(self, endpoint_name, **kwargs):
kwargs["query"] = backend._mongotize(kwargs["query"], endpoint_name)

result = backend.driver.db[endpoint_name].find_one_and_update(**kwargs)
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)
return result

async def find_and_modify_async(self, endpoint_name, **kwargs):
Expand All @@ -288,7 +289,7 @@ async def find_and_modify_async(self, endpoint_name, **kwargs):
kwargs["query"] = backend._mongotize(kwargs["query"], endpoint_name)

result = await backend.driver.db[endpoint_name].find_one_and_update(**kwargs)
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)
return result

def get_mongo_collection(self, endpoint_name) -> pymongo.collection.Collection:
Expand All @@ -311,7 +312,7 @@ def create(self, endpoint_name, docs, **kwargs):
doc.pop("_type", None)
ids = self.create_in_mongo(endpoint_name, docs, **kwargs)
self.create_in_search(endpoint_name, docs, **kwargs)
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)

for doc in docs:
self._push_resource_notification("created", endpoint_name, _id=str(doc["_id"]))
Expand All @@ -328,7 +329,7 @@ async def create_async(self, endpoint_name, docs, **kwargs):
doc.pop("_type", None)
ids = await self.create_in_mongo_async(endpoint_name, docs, **kwargs)
await self.create_in_search_async(endpoint_name, docs, **kwargs)
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)

for doc in docs:
self._push_resource_notification("created", endpoint_name, _id=str(doc["_id"]))
Expand Down Expand Up @@ -512,7 +513,7 @@ def _change_request(self, endpoint_name, id, updates, original, change_request=F
logger.warning("Item is missing in elastic resource=%s id=%s", endpoint_name, id)
search_backend.insert(endpoint_name, [doc])

cache.clean([endpoint_name])
self._clean_cache(endpoint_name)
return updates

async def _change_request_async(
Expand Down Expand Up @@ -565,7 +566,7 @@ async def _change_request_async(
logger.warning("Item is missing in elastic resource=%s id=%s", endpoint_name, id)
await search_backend.insert(endpoint_name, [doc])

cache.clean([endpoint_name])
self._clean_cache(endpoint_name)
return updates

def replace(self, endpoint_name, id, document, original):
Expand All @@ -578,7 +579,7 @@ def replace(self, endpoint_name, id, document, original):
"""
res = self.replace_in_mongo(endpoint_name, id, document, original)
self.replace_in_search(endpoint_name, id, document, original)
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)

# with soft delete enabled eve uses replace to update the document
if document.get("_deleted") and not original.get("_deleted"):
Expand All @@ -600,7 +601,7 @@ async def replace_async(self, endpoint_name, id, document, original):
"""
res = await self.replace_in_mongo_async(endpoint_name, id, document, original)
await self.replace_in_search_async(endpoint_name, id, document, original)
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)

# with soft delete enabled eve uses replace to update the document
if document.get("_deleted") and not original.get("_deleted"):
Expand Down Expand Up @@ -720,7 +721,7 @@ def delete(self, endpoint_name, lookup):
removed_ids = [lookup["_id"]]
except NotFoundError:
pass # not found in elastic and not in mongo
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)
return removed_ids

async def delete_async(self, endpoint_name, lookup):
Expand All @@ -745,7 +746,7 @@ async def delete_async(self, endpoint_name, lookup):
removed_ids = [lookup["_id"]]
except NotFoundError:
pass # not found in elastic and not in mongo
cache.clean([endpoint_name])
self._clean_cache(endpoint_name)
return removed_ids

def delete_docs(self, endpoint_name, docs):
Expand Down Expand Up @@ -895,6 +896,10 @@ def _lookup_backend(self, endpoint_name, fallback=False, use_async: bool = False
backend = app.data._backend(endpoint_name, use_async)
return backend

def _clean_cache(self, endpoint_name: str) -> None:
    """Invalidate the cache for *endpoint_name*, without blocking the caller.

    Only services that opt into caching (``uses_cache`` attribute, set by
    ``CacheableService``) trigger an invalidation; for all other resources
    this is a no-op.
    """
    service = superdesk.get_resource_service(endpoint_name)
    if getattr(service, "uses_cache", False):
        cache.clean_in_thread([endpoint_name])
Comment thread
MarkLark86 marked this conversation as resolved.

def set_default_dates(self, doc):
"""Helper to populate ``_created`` and ``_updated`` timestamps."""
now = utcnow()
Expand Down
4 changes: 4 additions & 0 deletions superdesk/factory/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
from superdesk.core.resources import ResourceRestEndpoints, ResourceConfig
from superdesk.core.resources.validators import convert_pydantic_validation_error_for_response
from superdesk.core.web import NullEndpoint
from superdesk.core.tasks import wait_thread_tasks_to_complete

SUPERDESK_PATH = os.path.abspath(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

Expand Down Expand Up @@ -254,6 +255,9 @@ def __init__(self, **kwargs):

self.on_get_api_root += self.extend_eve_home_endpoint

# Wait for background tasks to complete when gracefully shutting down
self.after_serving_funcs.append(wait_thread_tasks_to_complete)

def __getattr__(self, name):
"""Only use events for on_* methods."""
if name.startswith("on_"):
Expand Down
1 change: 1 addition & 0 deletions superdesk/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ class Service(BaseService):
class CacheableService(BaseService):
"""Handles caching for the resource, will invalidate on any changes to the resource."""

uses_cache = True
datasource: str
cache_lookup = {}

Expand Down
89 changes: 89 additions & 0 deletions tests/core/tasks_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import asyncio
import time
from unittest.mock import patch
from superdesk.core import tasks
from superdesk.tests import AsyncTestCase


class TasksTestCase(AsyncTestCase):
    async def asyncTearDown(self):
        # Make sure no background tasks leak between tests
        await tasks.wait_thread_tasks_to_complete(timeout=1)
        await super().asyncTearDown()

    async def test_run_in_thread(self):
        """run_in_thread executes the callable with its args and returns a tracked task."""

        def sample(arg1, kwarg1=None):
            self.assertEqual(arg1, "foo")
            self.assertEqual(kwarg1, "bar")
            return "success"

        task = tasks.run_in_thread(sample, "foo", kwarg1="bar")

        self.assertIsInstance(task, asyncio.Task)
        self.assertIn(task, tasks._thread_tasks)
        self.assertTrue(task.get_name().startswith("thread_task__"))

        # Let the task finish; its done-callback runs as part of completion
        await task
        self.assertEqual(task.result(), "success")

        # The done-callback must have untracked the task
        self.assertNotIn(task, tasks._thread_tasks)

    async def test_wait_thread_tasks_to_complete(self):
        """wait_thread_tasks_to_complete cancels pending tasks and waits for them."""

        def never_finishes():
            # Sleeps longer than the shutdown timeout, so only cancellation
            # can end the task in time
            time.sleep(5)
            raise Exception("This should not be reached")

        with patch("superdesk.core.tasks.logger") as mock_logger:
            task = tasks.run_in_thread(never_finishes)
            self.assertIn(task, tasks._thread_tasks)

            await tasks.wait_thread_tasks_to_complete(timeout=1)

            mock_logger.warning.assert_called_with(
                "Background task was cancelled", extra={"task_name": task.get_name()}
            )

        self.assertTrue(task.done())
        self.assertTrue(task.cancelled())
        self.assertEqual(len(tasks._thread_tasks), 0)

    async def test_handle_background_task_result_exception(self):
        """The done-callback logs exceptions raised by the threaded callable."""

        def failing_func():
            raise ValueError("test exception")

        with patch("superdesk.core.tasks.logger") as mock_logger:
            task = tasks.run_in_thread(failing_func)

            with self.assertRaises(ValueError):
                await task

            mock_logger.exception.assert_called_with("Background task failed", extra={"task_name": task.get_name()})
            self.assertNotIn(task, tasks._thread_tasks)

    async def test_handle_background_task_result_cancelled(self):
        """The done-callback logs cancellation of a running task."""

        def slow_func():
            time.sleep(0.5)

        with patch("superdesk.core.tasks.logger") as mock_logger:
            task = tasks.run_in_thread(slow_func)
            task.cancel()

            with self.assertRaises(asyncio.CancelledError):
                await task

            mock_logger.warning.assert_called_with(
                "Background task was cancelled", extra={"task_name": task.get_name()}
            )
            self.assertNotIn(task, tasks._thread_tasks)
Loading