diff --git a/.github/workflows/post_release_to_hacker_news.yml b/.github/workflows/post_release_to_hacker_news.yml deleted file mode 100644 index c21287558bd..00000000000 --- a/.github/workflows/post_release_to_hacker_news.yml +++ /dev/null @@ -1,18 +0,0 @@ -on: - release: - types: [released] - -permissions: {} -jobs: - post_release_to_hacker_news: - runs-on: ubuntu-latest - name: Post Release to Hacker News - steps: - - name: Post the Release - uses: MicahLyle/github-action-post-to-hacker-news@v1 - env: - HN_USERNAME: ${{ secrets.HN_USERNAME }} - HN_PASSWORD: ${{ secrets.HN_PASSWORD }} - HN_TITLE_FORMAT_SPECIFIER: Celery v%s Released! - HN_URL_FORMAT_SPECIFIER: https://docs.celeryq.dev/en/v%s/changelog.html - HN_TEST_MODE: true diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index df76966793a..7a30911874f 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -29,10 +29,10 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', 'pypy-3.7', 'pypy-3.8'] + python-version: ['3.7', '3.8', '3.9', '3.10', 'pypy-3.9', 'pypy-3.8'] os: ["ubuntu-latest", "windows-latest"] exclude: - - python-version: 'pypy-3.7' + - python-version: 'pypy-3.9' os: "windows-latest" - python-version: 'pypy-3.8' os: "windows-latest" @@ -120,7 +120,7 @@ jobs: run: | echo "::set-output name=dir::$(pip cache dir)" - name: Install tox - run: python -m pip install tox + run: python -m pip install --upgrade pip tox tox-gh-actions - name: > Run tox for "${{ matrix.python-version }}-integration-${{ matrix.toxenv }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f91e4309713..16d19389cbc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/asottile/pyupgrade - rev: v3.2.0 + rev: v3.3.1 hooks: - id: pyupgrade args: ["--py37-plus"] - repo: https://github.com/PyCQA/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 @@ -16,7 +16,7 @@ repos: - id: yesqa - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-merge-conflict - id: check-toml @@ -24,12 +24,12 @@ repos: - id: mixed-line-ending - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: v5.11.3 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.982 + rev: v0.991 hooks: - id: mypy pass_filenames: false diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 4b99f190dbe..e8c1dec868b 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -290,3 +290,5 @@ Gabor Boros, 2021/11/09 Tizian Seehaus, 2022/02/09 Oleh Romanovskyi, 2022/06/09 JoonHwan Kim, 2022/08/01 +Kaustav Banerjee, 2022/11/10 +Austin Snoeyink 2022/12/06 diff --git a/celery/__init__.py b/celery/__init__.py index 7c2de763898..aa64b596f0a 100644 --- a/celery/__init__.py +++ b/celery/__init__.py @@ -70,8 +70,7 @@ def debug_import(name, locals=None, globals=None, from celery.app.base import Celery from celery.app.task import Task from celery.app.utils import bugreport - from celery.canvas import (chain, chord, chunks, group, maybe_signature, signature, subtask, xmap, # noqa - xstarmap) + from celery.canvas import chain, chord, chunks, group, maybe_signature, signature, subtask, xmap, xstarmap from celery.utils import uuid # Eventlet/gevent patching must happen before importing diff --git a/celery/app/amqp.py b/celery/app/amqp.py index e3245811035..9e52af4a66f 100644 --- a/celery/app/amqp.py +++ b/celery/app/amqp.py @@ -321,7 +321,7 @@ def 
as_task_v2(self, task_id, name, args=None, kwargs=None, if not root_id: # empty root_id defaults to task_id root_id = task_id - stamps = {header: maybe_list(options[header]) for header in stamped_headers or []} + stamps = {header: options[header] for header in stamped_headers or []} headers = { 'lang': 'py', 'task': name, diff --git a/celery/app/base.py b/celery/app/base.py index 6ca3eaf5ada..73ddf4e0f7d 100644 --- a/celery/app/base.py +++ b/celery/app/base.py @@ -33,8 +33,7 @@ from celery.utils.time import maybe_make_aware, timezone, to_utc # Load all builtin tasks -from . import builtins # noqa -from . import backends +from . import backends, builtins from .annotations import prepare as prepare_annotations from .autoretry import add_autoretry_behaviour from .defaults import DEFAULT_SECURITY_DIGEST, find_deprecated_settings @@ -458,6 +457,9 @@ def cons(app): sum([len(args), len(opts)]))) return inner_create_task_cls(**opts) + def type_checker(self, fun, bound=False): + return staticmethod(head_from_fun(fun, bound=bound)) + def _task_from_fun(self, fun, name=None, base=None, bind=False, **options): if not self.finalized and not self.autofinalize: raise RuntimeError('Contract breach: app not finalized') @@ -474,7 +476,7 @@ def _task_from_fun(self, fun, name=None, base=None, bind=False, **options): '__doc__': fun.__doc__, '__module__': fun.__module__, '__annotations__': fun.__annotations__, - '__header__': staticmethod(head_from_fun(fun, bound=bind)), + '__header__': self.type_checker(fun, bound=bind), '__wrapped__': run}, **options))() # for some reason __qualname__ cannot be set in type() # so we have to set it here. @@ -778,6 +780,10 @@ def send_task(self, name, args=None, kwargs=None, countdown=None, **options ) + stamped_headers = options.pop('stamped_headers', []) + for stamp in stamped_headers: + options.pop(stamp) + if connection: producer = amqp.Producer(connection, auto_declare=False) diff --git a/celery/app/defaults.py b/celery/app/defaults.py index ce8d0ae1a90..a9f68689940 100644 --- a/celery/app/defaults.py +++ b/celery/app/defaults.py @@ -78,6 +78,7 @@ def __repr__(self): scheduler=Option('celery.beat:PersistentScheduler'), schedule_filename=Option('celerybeat-schedule'), sync_every=Option(0, type='int'), + cron_starting_deadline=Option(None, type='int') ), broker=Namespace( url=Option(None, type='string'), @@ -89,6 +90,7 @@ def __repr__(self): connection_retry=Option(True, type='bool'), connection_retry_on_startup=Option(None, type='bool'), connection_max_retries=Option(100, type='int'), + channel_error_retry=Option(False, type='bool'), failover_strategy=Option(None, type='string'), heartbeat=Option(120, type='int'), heartbeat_checkrate=Option(3.0, type='int'), diff --git a/celery/app/task.py b/celery/app/task.py index 22794fd16de..c2d9784da33 100644 --- a/celery/app/task.py +++ b/celery/app/task.py @@ -953,10 +953,18 @@ def replace(self, sig): for t in reversed(self.request.chain or []): sig |= signature(t, app=self.app) # Stamping sig with parents groups - stamped_headers = self.request.stamped_headers if self.request.stamps: groups = self.request.stamps.get("groups") - sig.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers)) + sig.stamp(visitor=GroupStampingVisitor(groups=groups, stamped_headers=self.request.stamped_headers)) + stamped_headers = self.request.stamped_headers.copy() + stamps = self.request.stamps.copy() + stamped_headers.extend(sig.options.get('stamped_headers', [])) + stamps.update({ + stamp: value + for stamp, value in
sig.options.items() if stamp in sig.options.get('stamped_headers', []) + }) + sig.options['stamped_headers'] = stamped_headers + sig.options.update(stamps) return self.on_replace(sig) diff --git a/celery/app/trace.py b/celery/app/trace.py index 5307620d342..37eb57ef591 100644 --- a/celery/app/trace.py +++ b/celery/app/trace.py @@ -10,7 +10,7 @@ from collections import namedtuple from warnings import warn -from billiard.einfo import ExceptionInfo +from billiard.einfo import ExceptionInfo, ExceptionWithTraceback from kombu.exceptions import EncodeError from kombu.serialization import loads as loads_message from kombu.serialization import prepare_accept_content @@ -238,6 +238,8 @@ def handle_failure(self, task, req, store_errors=True, call_errbacks=True): def _log_error(self, task, req, einfo): eobj = einfo.exception = get_pickled_exception(einfo.exception) + if isinstance(eobj, ExceptionWithTraceback): + eobj = einfo.exception = eobj.exc exception, traceback, exc_info, sargs, skwargs = ( safe_repr(eobj), safe_str(einfo.traceback), diff --git a/celery/backends/base.py b/celery/backends/base.py index e851c8189f6..a8bf01a5929 100644 --- a/celery/backends/base.py +++ b/celery/backends/base.py @@ -397,7 +397,7 @@ def exception_to_python(self, exc): exc = cls(*exc_msg) else: exc = cls(exc_msg) - except Exception as err: # noqa + except Exception as err: exc = Exception(f'{cls}({exc_msg})') return exc @@ -817,11 +817,25 @@ class BaseKeyValueStoreBackend(Backend): def __init__(self, *args, **kwargs): if hasattr(self.key_t, '__func__'): # pragma: no cover self.key_t = self.key_t.__func__ # remove binding - self._encode_prefixes() super().__init__(*args, **kwargs) + self._add_global_keyprefix() + self._encode_prefixes() if self.implements_incr: self.apply_chord = self._apply_chord_incr + def _add_global_keyprefix(self): + """ + This method prepends the global keyprefix to the existing keyprefixes. + + This method checks if a global keyprefix is configured in `result_backend_transport_options` using the + `global_keyprefix` key. If so, then it is prepended to the task, group and chord key prefixes. + """ + global_keyprefix = self.app.conf.get('result_backend_transport_options', {}).get("global_keyprefix", None) + if global_keyprefix: + self.task_keyprefix = f"{global_keyprefix}_{self.task_keyprefix}" + self.group_keyprefix = f"{global_keyprefix}_{self.group_keyprefix}" + self.chord_keyprefix = f"{global_keyprefix}_{self.chord_keyprefix}" + def _encode_prefixes(self): self.task_keyprefix = self.key_t(self.task_keyprefix) self.group_keyprefix = self.key_t(self.group_keyprefix) diff --git a/celery/beat.py b/celery/beat.py index 4c9486532e3..a3d13adafb3 100644 --- a/celery/beat.py +++ b/celery/beat.py @@ -46,7 +46,7 @@ class SchedulingError(Exception): class BeatLazyFunc: - """An lazy function declared in 'beat_schedule' and called before sending to worker. + """A lazy function declared in 'beat_schedule' and called before sending to worker. 
Example: diff --git a/celery/bin/shell.py b/celery/bin/shell.py index 77b14d8a307..840bcc3c52f 100644 --- a/celery/bin/shell.py +++ b/celery/bin/shell.py @@ -67,10 +67,10 @@ def _no_ipython(self): # pragma: no cover def _invoke_default_shell(locals): try: - import IPython # noqa + import IPython except ImportError: try: - import bpython # noqa + import bpython except ImportError: _invoke_fallback_shell(locals) else: diff --git a/celery/canvas.py b/celery/canvas.py index 3d09d1879c5..aadd39003f5 100644 --- a/celery/canvas.py +++ b/celery/canvas.py @@ -92,7 +92,7 @@ def _merge_dictionaries(d1, d2): else: if isinstance(value, (int, float, str)): d1[key] = [value] - if isinstance(d2[key], list): + if isinstance(d2[key], list) and d1[key] is not None: d1[key].extend(d2[key]) else: if d1[key] is None: @@ -161,7 +161,6 @@ def on_signature(self, sig, **headers) -> dict: Returns: Dict: headers to update. """ - pass def on_chord_header_start(self, chord, **header) -> dict: """Method that is called on chord header stamping start. @@ -311,6 +310,12 @@ class Signature(dict): @classmethod def register_type(cls, name=None): + """Register a new type of signature. + Used as a class decorator, for example: + >>> @Signature.register_type() + >>> class mysig(Signature): + >>> pass + """ def _inner(subclass): cls.TYPES[name or subclass.__name__] = subclass return subclass @@ -319,6 +324,10 @@ def _inner(subclass): @classmethod def from_dict(cls, d, app=None): + """Create a new signature from a dict. + Subclasses can override this method to customize how + they are created from a dict. + """ typ = d.get('subtask_type') if typ: target_cls = cls.TYPES[typ] @@ -413,6 +422,24 @@ def apply_async(self, args=None, kwargs=None, route_name=None, **options): return _apply(args, kwargs, **options) def _merge(self, args=None, kwargs=None, options=None, force=False): + """Merge partial args/kwargs/options with existing ones. + + If the signature is immutable and ``force`` is False, the existing + args/kwargs will be returned as-is and only the options will be merged. + + Stamped headers are considered immutable and will not be merged regardless. + + Arguments: + args (Tuple): Partial args to be prepended to the existing args. + kwargs (Dict): Partial kwargs to be merged with existing kwargs. + options (Dict): Partial options to be merged with existing options. + force (bool): If True, the args/kwargs will be merged even if the signature is + immutable. The stamped headers are not affected by this option and will not + be merged regardless. + + Returns: + Tuple: (args, kwargs, options) + """ args = args if args else () kwargs = kwargs if kwargs else {} if options is not None: @@ -423,6 +450,7 @@ def _merge(self, args=None, kwargs=None, options=None, force=False): immutable_options = self._IMMUTABLE_OPTIONS if "stamped_headers" in self.options: immutable_options = self._IMMUTABLE_OPTIONS.union(set(self.options["stamped_headers"])) + # merge self.options with options without overriding stamped headers from self.options new_options = {**self.options, **{ k: v for k, v in options.items() if k not in immutable_options or k not in self.options @@ -471,6 +499,18 @@ def freeze(self, _id=None, group_id=None, chord=None, twice after freezing it as that'll result in two task messages using the same task id. + The arguments are used to override the signature's headers during + freezing. + + Arguments: + _id (str): Task id to use if it didn't already have one. + A new UUID is generated if not provided.
+ group_id (str): Group id to use if it didn't already have one. + chord (Signature): Chord body when freezing a chord header. + root_id (str): Root id to use. + parent_id (str): Parent id to use. + group_index (int): Group index to use. + Returns: ~@AsyncResult: promise of future evaluation. """ @@ -594,18 +634,34 @@ def stamp_links(self, visitor, **headers): link.stamp(visitor=visitor, **headers) def _with_list_option(self, key): + """Gets the value at the given self.options[key] as a list. + + If the value is not a list, it will be converted to one and saved in self.options. + If the key does not exist, an empty list will be set and returned instead. + + Arguments: + key (str): The key to get the value for. + + Returns: + List: The value at the given key as a list or an empty list if the key does not exist. + """ items = self.options.setdefault(key, []) if not isinstance(items, MutableSequence): items = self.options[key] = [items] return items def append_to_list_option(self, key, value): + """Appends the given value to the list at the given key in self.options.""" items = self._with_list_option(key) if value not in items: items.append(value) return value def extend_list_option(self, key, value): + """Extends the list at the given key in self.options with the given value. + + If the value is not a list, it will be converted to one. + """ items = self._with_list_option(key) items.extend(maybe_list(value)) @@ -652,6 +708,14 @@ def flatten_links(self): ))) def __or__(self, other): + """Chaining operator. + + Example: + >>> add.s(2, 2) | add.s(4) | add.s(8) + + Returns: + chain: Constructs a :class:`~celery.canvas.chain` of the given signatures. + """ if isinstance(other, _chain): # task | chain -> chain return _chain(seq_concat_seq( @@ -685,6 +749,16 @@ def election(self): return type.AsyncResult(tid) def reprcall(self, *args, **kwargs): + """Return a string representation of the signature. + + Merges the given arguments with the signature's arguments + only for the purpose of generating the string representation. + The signature itself is not modified. + + Example: + >>> add.s(2, 2).reprcall() + 'add(2, 2)' + """ args, kwargs, _ = self._merge(args, kwargs, {}, force=True) return reprcall(self['task'], args, kwargs) @@ -841,6 +915,10 @@ def __or__(self, other): if not tasks: # If the chain is empty, return the group return other + if isinstance(tasks[-1], chord): + # CHAIN [last item is chord] | GROUP -> chain with chord body. + tasks[-1].body = tasks[-1].body | other + return type(self)(tasks, app=self.app) # use type(self) for _chain subclasses return type(self)(seq_concat_item( tasks, other), app=self._app) @@ -879,9 +957,19 @@ def clone(self, *args, **kwargs): return signature def unchain_tasks(self): + """Return a list of tasks in the chain. + + The tasks list would be cloned from the chain's tasks. + All of the chain callbacks would be added to the last task in the (cloned) chain. + All of the tasks would be linked to the same error callback + as the chain itself, to ensure that the correct error callback is called + if any of the (cloned) tasks of the chain fail. + """ # Clone chain's tasks assigning signatures from link_error - # to each task + # to each task and adding the chain's links to the last task. 
tasks = [t.clone() for t in self.tasks] + for sig in self.options.get('link', []): + tasks[-1].link(sig) for sig in self.options.get('link_error', []): for task in tasks: task.link_error(sig) @@ -903,6 +991,12 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None, task_id=None, link=None, link_error=None, publisher=None, producer=None, root_id=None, parent_id=None, app=None, group_index=None, **options): + """Executes the chain. + + Responsible for executing the chain in the correct order. + In the case of a chain with a single task, the task is executed + directly and its result is returned. + """ # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. args = args if args else () @@ -914,6 +1008,7 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None, args = (tuple(args) + tuple(self.args) if args and not self.immutable else self.args) + # Unpack nested chains/groups/chords tasks, results_from_prepare = self.prepare_steps( args, kwargs, self.tasks, root_id, parent_id, link_error, app, task_id, group_id, chord, group_index=group_index, @@ -924,6 +1019,8 @@ def run(self, args=None, kwargs=None, group_id=None, chord=None, visitor = GroupStampingVisitor(groups=groups, stamped_headers=stamped_headers) self.stamp(visitor=visitor) + # For a chain with a single task, execute the task directly and return its result + # For a chain of multiple tasks, execute all of the tasks and return the AsyncResult for the chain if results_from_prepare: if link: tasks[0].extend_list_option('link', link) @@ -971,6 +1068,38 @@ def prepare_steps(self, args, kwargs, tasks, last_task_id=None, group_id=None, chord_body=None, clone=True, from_dict=Signature.from_dict, group_index=None): + """Prepare the chain for execution. + + To execute a chain, we first need to unpack it correctly. + During the unpacking, we might encounter other chains, groups, or chords + which we need to unpack as well. + + For example: + chain(signature1, chain(signature2, signature3)) --> Upgrades to chain(signature1, signature2, signature3) + chain(group(signature1, signature2), signature3) --> Upgrades to chord([signature1, signature2], signature3) + + The responsibility of this method is to ensure that the chain is + correctly unpacked, and then the correct callbacks are set up along the way. + + Arguments: + args (Tuple): Partial args to be prepended to the existing args. + kwargs (Dict): Partial kwargs to be merged with existing kwargs. + tasks (List[Signature]): The tasks of the chain. + root_id (str): The id of the root task. + parent_id (str): The id of the parent task. + link_error (Union[List[Signature], Signature]): The error callback, + which will be set for all tasks in the chain. + app (Celery): The Celery app instance. + last_task_id (str): The id of the last task in the chain. + group_id (str): The id of the group that the chain is a part of. + chord_body (Signature): The body of the chord, used to synchronize with the chain's + last task and the chord's body when used together. + clone (bool): Whether to clone the chain's tasks before modifying them. + from_dict (Callable): A function that takes a dict and returns a Signature. + + Returns: + Tuple[List[Signature], List[AsyncResult]]: The frozen tasks of the chain, and the async results. + """ app = app or self.app # use chain message field for protocol 2 and later.
# this avoids pickle blowing the stack on the recursion @@ -1325,6 +1454,50 @@ class group(Signature): @classmethod def from_dict(cls, d, app=None): + """Create a group signature from a dictionary that represents a group. + + Example: + >>> group_dict = { "task": "celery.group", "args": [], "kwargs": { "tasks": [ { "task": "add", "args": [ 1, 2 ], "kwargs": {}, "options": {}, "subtask_type": None, "immutable": False }, { "task": "add", "args": [ 3, 4 ], "kwargs": {}, "options": {}, "subtask_type": None, "immutable": False } ] }, "options": {}, "subtask_type": "group", "immutable": False } + >>> group_sig = group.from_dict(group_dict) + + Iterates over the given tasks in the dictionary and converts them to signatures. + Tasks need to be defined in d['kwargs']['tasks'] as a sequence + of tasks. + + The tasks themselves can be dictionaries or signatures (or both). + """ # We need to mutate the `kwargs` element in place to avoid confusing # `freeze()` implementations which end up here and expect to be able to # access elements from that dictionary later and refer to objects @@ -1343,6 +1516,8 @@ def __init__(self, *tasks, **options): if isinstance(tasks, abstract.CallableSignature): tasks = [tasks.clone()] if not isinstance(tasks, _regen): + # May potentially cause slowdowns when using a + # generator of many tasks - Issue #6973 tasks = regen(tasks) super().__init__('celery.group', (), {'tasks': tasks}, **options ) @@ -1356,6 +1531,7 @@ def __or__(self, other): return chord(self, body=other, app=self._app) def skew(self, start=1.0, stop=None, step=1.0): + # TODO: Not sure if this is still used anywhere (besides its own tests). Consider removing. it = fxrange(start, stop, step, repeatlast=True) for task in self.tasks: task.set(countdown=next(it)) @@ -1468,6 +1644,32 @@ def _prepared(self, tasks, partial_args, group_id, root_id, app, CallableSignature=abstract.CallableSignature, from_dict=Signature.from_dict, isinstance=isinstance, tuple=tuple): + """Recursively unroll the group into a generator of its tasks. + + This is used by :meth:`apply_async` and :meth:`apply` to + unroll the group into a list of tasks that can be evaluated. + + Note: + This does not change the group itself, it only returns + a generator of the tasks that the group would evaluate to. + + Arguments: + tasks (list): List of tasks in the group (may contain nested groups). + partial_args (list): List of arguments to be prepended to + the arguments of each task. + group_id (str): The group id of the group. + root_id (str): The root id of the group. + app (Celery): The Celery app instance. + CallableSignature (class): The signature class of the group's tasks. + from_dict (fun): Function to create a signature from a dict. + isinstance (fun): Function to check if an object is an instance + of a class. + tuple (class): A tuple-like class. + + Returns: + generator: A generator for the unrolled group tasks. + The generator yields tuples of the form ``(task, AsyncResult, group_id)``. + """ for task in tasks: if isinstance(task, CallableSignature): # local sigs are always of type Signature, and we @@ -1490,6 +1692,25 @@ def _prepared(self, tasks, partial_args, group_id, root_id, app, def _apply_tasks(self, tasks, producer=None, app=None, p=None, add_to_parent=None, chord=None, args=None, kwargs=None, **options): + """Run all the tasks in the group. + + This is used by :meth:`apply_async` to run all the tasks in the group + and return a generator of their results.
+ + Arguments: + tasks (list): List of tasks in the group. + producer (Producer): The producer to use to publish the tasks. + app (Celery): The Celery app instance. + p (barrier): Barrier object to synchronize the tasks results. + args (list): List of arguments to be prepended to + the arguments of each task. + kwargs (dict): Dict of keyword arguments to be merged with + the keyword arguments of each task. + **options (dict): Options to be merged with the options of each task. + + Returns: + generator: A generator for the AsyncResult of the tasks in the group. + """ # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. app = app or self.app @@ -1533,6 +1754,7 @@ def _apply_tasks(self, tasks, producer=None, app=None, p=None, yield res # <-- r.parent, etc set in the frozen result. def _freeze_gid(self, options): + """Freeze the group id by the existing task_id or a new UUID.""" # remove task_id and use that as the group_id, # if we don't remove it then every task will have the same id... options = {**self.options, **{ @@ -1545,6 +1767,15 @@ def _freeze_gid(self, options): def _freeze_group_tasks(self, _id=None, group_id=None, chord=None, root_id=None, parent_id=None, group_index=None): + """Freeze the tasks in the group. + + Note: + If the group tasks are created from a generator, the tasks generator would + not be exhausted, and the tasks would be frozen lazily. + + Returns: + tuple: A tuple of the group id, and the AsyncResult of each of the group tasks. + """ # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. opts = self.options @@ -1561,15 +1792,16 @@ def _freeze_group_tasks(self, _id=None, group_id=None, chord=None, root_id = opts.setdefault('root_id', root_id) parent_id = opts.setdefault('parent_id', parent_id) if isinstance(self.tasks, _regen): - # We are draining from a generator here. - # tasks1, tasks2 are each a clone of self.tasks + # When the group tasks are a generator, we need to make sure we don't + # exhaust it during the freeze process. We use two generators to do this. + # One generator will be used to freeze the tasks to get their AsyncResult. + # The second generator will be used to replace the tasks in the group with an unexhausted state. + + # Create two new generators from the original generator of the group tasks (cloning the tasks). tasks1, tasks2 = itertools.tee(self._unroll_tasks(self.tasks)) - # freeze each task in tasks1, results now holds AsyncResult for each task + # Use the first generator to freeze the group tasks to acquire the AsyncResult for each task. results = regen(self._freeze_tasks(tasks1, group_id, chord, root_id, parent_id)) - # TODO figure out why this makes sense - - # we freeze all tasks in the clone tasks1, and then zip the results - # with the IDs of tasks in the second clone, tasks2. and then, we build - # a generator that takes only the task IDs from tasks2. + # Use the second generator to replace the exhausted generator of the group tasks. 
self.tasks = regen(tasks2) else: new_tasks = [] @@ -1594,6 +1826,7 @@ def freeze(self, _id=None, group_id=None, chord=None, _freeze = freeze def _freeze_tasks(self, tasks, group_id, chord, root_id, parent_id): + """Creates a generator for the AsyncResult of each task in the tasks argument.""" yield from (task.freeze(group_id=group_id, chord=chord, root_id=root_id, @@ -1602,10 +1835,29 @@ def _freeze_tasks(self, tasks, group_id, chord, root_id, parent_id): for group_index, task in enumerate(tasks)) def _unroll_tasks(self, tasks): + """Creates a generator for the cloned tasks of the tasks argument.""" # should be refactored to: (maybe_signature(task, app=self._app, clone=True) for task in tasks) yield from (maybe_signature(task, app=self._app).clone() for task in tasks) def _freeze_unroll(self, new_tasks, group_id, chord, root_id, parent_id): + """Generator for the frozen flattened group tasks. + + Creates a flattened list of the tasks in the group, and freezes + each task in the group. Nested groups will be recursively flattened. + + Exhausting the generator will create a new list of the flattened + tasks in the group and will return it in the new_tasks argument. + + Arguments: + new_tasks (list): The list to append the flattened tasks to. + group_id (str): The group_id to use for the tasks. + chord (Chord): The chord to use for the tasks. + root_id (str): The root_id to use for the tasks. + parent_id (str): The parent_id to use for the tasks. + + Yields: + AsyncResult: The frozen task. + """ # pylint: disable=redefined-outer-name # XXX chord is also a class in outer scope. stack = deque(self.tasks) @@ -1674,6 +1926,60 @@ class _chord(Signature): @classmethod def from_dict(cls, d, app=None): + """Create a chord signature from a dictionary that represents a chord. + + Example: + >>> chord_dict = { "task": "celery.chord", "args": [], "kwargs": { "kwargs": {}, "header": [ { "task": "add", "args": [ 1, 2 ], "kwargs": {}, "options": {}, "subtask_type": None, "immutable": False }, { "task": "add", "args": [ 3, 4 ], "kwargs": {}, "options": {}, "subtask_type": None, "immutable": False } ], "body": { "task": "xsum", "args": [], "kwargs": {}, "options": {}, "subtask_type": None, "immutable": False } }, "options": {}, "subtask_type": "chord", "immutable": False } + >>> chord_sig = chord.from_dict(chord_dict) + + Iterates over the given tasks in the dictionary and converts them to signatures. + The chord header needs to be defined in d['kwargs']['header'] as a sequence + of tasks. + The chord body needs to be defined in d['kwargs']['body'] as a single task. + + The tasks themselves can be dictionaries or signatures (or both). + """ options = d.copy() args, options['kwargs'] = cls._unpack_args(**options['kwargs']) return cls(*args, app=app, **options) @@ -1704,6 +2010,13 @@ def __or__(self, other): sig = self.clone() sig.body = sig.body | other return sig + elif isinstance(other, group) and len(other.tasks) == 1: + # chord | group -> chain with chord body. + # unroll group with one member + other = maybe_unroll_group(other) + sig = self.clone() + sig.body = sig.body | other + return sig else: return super().__or__(other) @@ -1811,6 +2124,10 @@ def apply(self, args=None, kwargs=None, @classmethod def _descend(cls, sig_obj): + """Count the number of tasks in the given signature recursively. + + Descend into the signature object and return the number of tasks it contains.
+ """ # Sometimes serialized signatures might make their way here if not isinstance(sig_obj, Signature) and isinstance(sig_obj, dict): sig_obj = Signature.from_dict(sig_obj) @@ -1837,12 +2154,34 @@ def _descend(cls, sig_obj): return len(sig_obj) def __length_hint__(self): + """Return the number of tasks in this chord's header (recursively).""" tasks = getattr(self.tasks, "tasks", self.tasks) return sum(self._descend(task) for task in tasks) def run(self, header, body, partial_args, app=None, interval=None, countdown=1, max_retries=None, eager=False, task_id=None, kwargs=None, **options): + """Execute the chord. + + Executing the chord means executing the header and sending the + result to the body. In case of an empty header, the body is + executed immediately. + + Arguments: + header (group): The header to execute. + body (Signature): The body to execute. + partial_args (tuple): Arguments to pass to the header. + app (Celery): The Celery app instance. + interval (float): The interval between retries. + countdown (int): The countdown between retries. + max_retries (int): The maximum number of retries. + task_id (str): The task id to use for the body. + kwargs (dict): Keyword arguments to pass to the header. + options (dict): Options to pass to the header. + + Returns: + AsyncResult: The result of the body (with the result of the header in the parent of the body). + """ app = app or self._get_app(body) group_id = header.options.get('task_id') or uuid() root_id = body.options.get('root_id') @@ -1894,18 +2233,28 @@ def clone(self, *args, **kwargs): return signature def link(self, callback): + """Links a callback to the chord body only.""" self.body.link(callback) return callback def link_error(self, errback): + """Links an error callback to the chord body, and potentially the header as well. + + Note: + The ``task_allow_error_cb_on_chord_header`` setting controls whether + error callbacks are allowed on the header. If this setting is + ``False`` (the current default), then the error callback will only be + applied to the body. + """ if self.app.conf.task_allow_error_cb_on_chord_header: - # self.tasks can be a list of the chord header workflow. - if isinstance(self.tasks, (list, tuple)): - for task in self.tasks: - task.link_error(errback) - else: - self.tasks.link_error(errback) + for task in self.tasks: + task.link_error(errback) else: + # Once this warning is removed, the whole method needs to be refactored to: + # 1. link the error callback to each task in the header + # 2. link the error callback to the body + # 3. return the error callback + # In summary, up to 4 lines of code + updating the method docstring. warnings.warn( "task_allow_error_cb_on_chord_header=False is pending deprecation in " "a future release of Celery.\n" @@ -1919,7 +2268,14 @@ def link_error(self, errback): return errback def set_immutable(self, immutable): - # changes mutability of header only, not callback. + """Sets the immutable flag on the chord header only. + + Note: + Does not affect the chord body. + + Arguments: + immutable (bool): The new mutability value for chord header. 
+ """ for task in self.tasks: task.set_immutable(immutable) diff --git a/celery/concurrency/__init__.py b/celery/concurrency/__init__.py index a326c79aff2..54eabfa2543 100644 --- a/celery/concurrency/__init__.py +++ b/celery/concurrency/__init__.py @@ -1,4 +1,5 @@ """Pool implementation abstract factory, and alias definitions.""" +import os # Import from kombu directly as it's used # early in the import stage, where celery.utils loads @@ -16,11 +17,25 @@ } try: - import concurrent.futures # noqa: F401 + import concurrent.futures except ImportError: pass else: ALIASES['threads'] = 'celery.concurrency.thread:TaskPool' +# +# Allow for an out-of-tree worker pool implementation. This is used as follows: +# +# - Set the environment variable CELERY_CUSTOM_WORKER_POOL to the name of +# an implementation of :class:`celery.concurrency.base.BasePool` in the +# standard Celery format of "package:class". +# - Select this pool using '--pool custom'. +# +try: + custom = os.environ.get('CELERY_CUSTOM_WORKER_POOL') +except KeyError: + pass +else: + ALIASES['custom'] = custom def get_implementation(cls): diff --git a/celery/concurrency/asynpool.py b/celery/concurrency/asynpool.py index 19715005828..b735e7b1014 100644 --- a/celery/concurrency/asynpool.py +++ b/celery/concurrency/asynpool.py @@ -57,7 +57,7 @@ def __read__(fd, buf, size, read=os.read): return n readcanbuf = False - def unpack_from(fmt, iobuf, unpack=unpack): # noqa + def unpack_from(fmt, iobuf, unpack=unpack): return unpack(fmt, iobuf.getvalue()) # <-- BytesIO __all__ = ('AsynPool',) diff --git a/celery/concurrency/base.py b/celery/concurrency/base.py index 0b4db3fbf35..1ce9a751ea2 100644 --- a/celery/concurrency/base.py +++ b/celery/concurrency/base.py @@ -3,6 +3,7 @@ import os import sys import time +from typing import Any, Dict from billiard.einfo import ExceptionInfo from billiard.exceptions import WorkerLostError @@ -154,8 +155,15 @@ def apply_async(self, target, args=None, kwargs=None, **options): callbacks_propagate=self.callbacks_propagate, **options) - def _get_info(self): + def _get_info(self) -> Dict[str, Any]: + """ + Return configuration and statistics information. Subclasses should + augment the data as required. + + :return: The returned value must be JSON-friendly. 
+ """ return { + 'implementation': self.__class__.__module__ + ':' + self.__class__.__name__, 'max-concurrency': self.limit, } diff --git a/celery/concurrency/prefork.py b/celery/concurrency/prefork.py index 40772ebae1a..b163328d0b3 100644 --- a/celery/concurrency/prefork.py +++ b/celery/concurrency/prefork.py @@ -155,7 +155,8 @@ def on_close(self): def _get_info(self): write_stats = getattr(self._pool, 'human_write_stats', None) - return { + info = super()._get_info() + info.update({ 'max-concurrency': self.limit, 'processes': [p.pid for p in self._pool._pool], 'max-tasks-per-child': self._pool._maxtasksperchild or 'N/A', @@ -163,7 +164,8 @@ def _get_info(self): 'timeouts': (self._pool.soft_timeout or 0, self._pool.timeout or 0), 'writes': write_stats() if write_stats is not None else 'N/A', - } + }) + return info @property def num_processes(self): diff --git a/celery/concurrency/solo.py b/celery/concurrency/solo.py index ea6e274a3ba..e7e9c7f3ba4 100644 --- a/celery/concurrency/solo.py +++ b/celery/concurrency/solo.py @@ -20,10 +20,12 @@ def __init__(self, *args, **kwargs): signals.worker_process_init.send(sender=None) def _get_info(self): - return { + info = super()._get_info() + info.update({ 'max-concurrency': 1, 'processes': [os.getpid()], 'max-tasks-per-child': None, 'put-guarded-by-semaphore': True, 'timeouts': (), - } + }) + return info diff --git a/celery/concurrency/thread.py b/celery/concurrency/thread.py index 120374bcf9b..b9c23e0173a 100644 --- a/celery/concurrency/thread.py +++ b/celery/concurrency/thread.py @@ -61,7 +61,9 @@ def on_apply( return ApplyResult(f) def _get_info(self) -> PoolInfo: - return { + info = super()._get_info() + info.update({ 'max-concurrency': self.limit, 'threads': len(self.executor._threads) - } + }) + return info diff --git a/celery/contrib/testing/app.py b/celery/contrib/testing/app.py index 95ed700b8ec..b8bd9f0d77a 100644 --- a/celery/contrib/testing/app.py +++ b/celery/contrib/testing/app.py @@ -47,7 +47,7 @@ def __init__(self, *args, **kwargs): def TestApp(name=None, config=None, enable_logging=False, set_as_current=False, log=UnitLogging, backend=None, broker=None, **kwargs): """App used for testing.""" - from . import tasks # noqa + from . import tasks config = dict(deepcopy(DEFAULT_TEST_CONFIG), **config or {}) if broker is not None: config.pop('broker_url', None) diff --git a/celery/schedules.py b/celery/schedules.py index 62940132098..9798579754f 100644 --- a/celery/schedules.py +++ b/celery/schedules.py @@ -36,7 +36,6 @@ {0._orig_day_of_week} (m/h/dM/MY/d)>\ """ - SOLAR_INVALID_LATITUDE = """\ Argument latitude {lat} is invalid, must be between -90 and 90.\ """ @@ -608,16 +607,48 @@ def remaining_estimate(self, last_run_at, ffwd=ffwd): def is_due(self, last_run_at): """Return tuple of ``(is_due, next_time_to_run)``. + If :setting:`beat_cron_starting_deadline` has been specified, the + scheduler will make sure that the `last_run_at` time is within the + deadline. This prevents tasks that could have been run according to + the crontab, but didn't, from running again unexpectedly. + Note: Next time to run is in seconds. SeeAlso: :meth:`celery.schedules.schedule.is_due` for more information. 
""" + rem_delta = self.remaining_estimate(last_run_at) - rem = max(rem_delta.total_seconds(), 0) + rem_secs = rem_delta.total_seconds() + rem = max(rem_secs, 0) due = rem == 0 - if due: + + deadline_secs = self.app.conf.beat_cron_starting_deadline + has_passed_deadline = False + if deadline_secs is not None: + # Make sure we're looking at the latest possible feasible run + # date when checking the deadline. + last_date_checked = last_run_at + last_feasible_rem_secs = rem_secs + while rem_secs < 0: + last_date_checked = last_date_checked + abs(rem_delta) + rem_delta = self.remaining_estimate(last_date_checked) + rem_secs = rem_delta.total_seconds() + if rem_secs < 0: + last_feasible_rem_secs = rem_secs + + # if rem_secs becomes 0 or positive, second-to-last + # last_date_checked must be the last feasible run date. + # Check if the last feasible date is within the deadline + # for running + has_passed_deadline = -last_feasible_rem_secs > deadline_secs + if has_passed_deadline: + # Should not be due if we've passed the deadline for looking + # at past runs + due = False + + if due or has_passed_deadline: rem_delta = self.remaining_estimate(self.now()) rem = max(rem_delta.total_seconds(), 0) return schedstate(due, rem) diff --git a/celery/security/__init__.py b/celery/security/__init__.py index c801d98b1df..cea3c2ff78f 100644 --- a/celery/security/__init__.py +++ b/celery/security/__init__.py @@ -36,7 +36,7 @@ __all__ = ('setup_security',) try: - import cryptography # noqa + import cryptography except ImportError: raise ImproperlyConfigured(CRYPTOGRAPHY_NOT_INSTALLED) diff --git a/celery/worker/consumer/consumer.py b/celery/worker/consumer/consumer.py index 6dd93ba7e57..d70dc179c78 100644 --- a/celery/worker/consumer/consumer.py +++ b/celery/worker/consumer/consumer.py @@ -124,7 +124,7 @@ These tasks cannot be acknowledged as the connection is gone, and the tasks are automatically redelivered back to the queue. You can enable this behavior using the worker_cancel_long_running_tasks_on_connection_loss setting. In Celery 5.1 it is set to False by default. The setting will be set to True by default in Celery 6.0. -""" # noqa: E501 +""" def dump_body(m, body): @@ -328,9 +328,13 @@ def start(self): crit('Frequent restarts detected: %r', exc, exc_info=1) sleep(1) self.restart_count += 1 + if self.app.conf.broker_channel_error_retry: + recoverable_errors = (self.connection_errors + self.channel_errors) + else: + recoverable_errors = self.connection_errors try: blueprint.start(self) - except self.connection_errors as exc: + except recoverable_errors as exc: # If we're not retrying connections, we need to properly shutdown or terminate # the Celery main process instead of abruptly aborting the process without any cleanup. 
is_connection_loss_on_startup = self.restart_count == 0 diff --git a/celery/worker/request.py b/celery/worker/request.py index b409bdc60da..ff8020a6f0f 100644 --- a/celery/worker/request.py +++ b/celery/worker/request.py @@ -327,7 +327,7 @@ def stamped_headers(self) -> list: @property def stamps(self) -> dict: - return {header: self._request_dict[header] for header in self.stamped_headers} + return {header: self._request_dict['stamps'][header] for header in self.stamped_headers} @property def correlation_id(self): diff --git a/celery/worker/state.py b/celery/worker/state.py index 74b28d4397e..1c7ab3942fa 100644 --- a/celery/worker/state.py +++ b/celery/worker/state.py @@ -103,11 +103,13 @@ def task_reserved(request, def task_accepted(request, _all_total_count=None, + add_request=requests.__setitem__, add_active_request=active_requests.add, add_to_total_count=total_count.update): """Update global state when a task has been accepted.""" if not _all_total_count: _all_total_count = all_total_count + add_request(request.id, request) add_active_request(request) add_to_total_count({request.name: 1}) all_total_count[0] += 1 diff --git a/docs/getting-started/backends-and-brokers/redis.rst b/docs/getting-started/backends-and-brokers/redis.rst index 9d42397de57..1c583f0bb27 100644 --- a/docs/getting-started/backends-and-brokers/redis.rst +++ b/docs/getting-started/backends-and-brokers/redis.rst @@ -100,6 +100,24 @@ If you are using Sentinel, you should specify the master_name using the :setting app.conf.result_backend_transport_options = {'master_name': "mymaster"} +.. _redis-result-backend-global-keyprefix: + +Global keyprefix +^^^^^^^^^^^^^^^^ + +The global key prefix will be prepended to all keys used for the result backend, +which can be useful when a Redis database is shared by different users. +By default, no prefix is prepended. + +To configure the global keyprefix for the Redis result backend, use the ``global_keyprefix`` key under :setting:`result_backend_transport_options`: + + +.. code-block:: python + + app.conf.result_backend_transport_options = { + 'global_keyprefix': 'my_prefix_' + } + .. _redis-result-backend-timeout: Connection timeouts diff --git a/docs/userguide/configuration.rst b/docs/userguide/configuration.rst index 5350d9fa2af..fbc22200cbd 100644 --- a/docs/userguide/configuration.rst +++ b/docs/userguide/configuration.rst @@ -2806,6 +2806,19 @@ to the AMQP broker. If this is set to :const:`0` or :const:`None`, we'll retry forever. +.. setting:: broker_channel_error_retry + +``broker_channel_error_retry`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 5.3 + +Default: Disabled. + +Automatically try to re-establish the connection to the AMQP broker +if a channel error (such as an invalid response from the broker) occurs. + +The retry count and interval are the same as those of `broker_connection_retry`. +This option has no effect when `broker_connection_retry` is `False`. + .. setting:: broker_login_method ``broker_login_method`` ~~~~~~~~~~~~~~~~~~~~~~~ @@ -3495,3 +3508,16 @@ changes to the schedule into account. Also when running Celery beat embedded (:option:`-B `) on Jython as a thread the max interval is overridden and set to 1 so that it's possible to shut down in a timely manner. + +.. setting:: beat_cron_starting_deadline + +``beat_cron_starting_deadline`` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. versionadded:: 5.3 + +Default: None. + +When using cron, the number of seconds :mod:`~celery.bin.beat` can look back +when deciding whether a cron schedule is due. When set to `None`, cron jobs that +are past due will always run immediately.
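A minimal configuration sketch of how the two settings documented above might be used together; the app name, broker URL, task name, and schedule entry below are illustrative assumptions, not part of this change:

.. code-block:: python

    from celery import Celery
    from celery.schedules import crontab

    app = Celery('myapp', broker='amqp://')  # hypothetical app and broker URL

    # Treat channel errors as recoverable and retry the broker connection,
    # reusing the retry count and interval of broker_connection_retry.
    # Has no effect when broker_connection_retry is False.
    app.conf.broker_channel_error_retry = True

    # A missed cron schedule only fires if its last feasible run time is at
    # most one hour in the past; older missed runs are skipped.
    app.conf.beat_cron_starting_deadline = 3600

    app.conf.beat_schedule = {
        'nightly-report': {  # hypothetical schedule entry
            'task': 'tasks.send_report',  # hypothetical task name
            'schedule': crontab(minute=0, hour=0),
        },
    }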
diff --git a/docs/userguide/tasks.rst b/docs/userguide/tasks.rst index 16a73ec6e79..6f9ceed528f 100644 --- a/docs/userguide/tasks.rst +++ b/docs/userguide/tasks.rst @@ -1432,9 +1432,11 @@ The above can be added to each task like this: .. code-block:: python - @app.task(base=DatabaseTask) - def process_rows(): - for row in process_rows.db.table.all(): + from celery.app.task import Task + + @app.task(base=DatabaseTask, bind=True) + def process_rows(self: Task): + for row in self.db.table.all(): process_row(row) The ``db`` attribute of the ``process_rows`` task will then diff --git a/examples/celery_http_gateway/urls.py b/examples/celery_http_gateway/urls.py index 802ff2344b2..7b74284c137 100644 --- a/examples/celery_http_gateway/urls.py +++ b/examples/celery_http_gateway/urls.py @@ -1,5 +1,5 @@ from celery_http_gateway.tasks import hello_world -from django.conf.urls.defaults import handler404, handler500, include, patterns, url # noqa +from django.conf.urls.defaults import handler404, handler500, include, patterns, url from djcelery import views as celery_views # Uncomment the next two lines to enable the admin: diff --git a/examples/django/proj/urls.py b/examples/django/proj/urls.py index 5f67c27b660..74415c35830 100644 --- a/examples/django/proj/urls.py +++ b/examples/django/proj/urls.py @@ -1,4 +1,4 @@ -from django.urls import handler404, handler500, include, url # noqa +from django.urls import handler404, handler500, include, url # Uncomment the next two lines to enable the admin: # from django.contrib import admin diff --git a/examples/django/requirements.txt b/examples/django/requirements.txt index 4ba37fb5b8a..ef6d5a6de00 100644 --- a/examples/django/requirements.txt +++ b/examples/django/requirements.txt @@ -1,3 +1,3 @@ django>=2.2.1 -sqlalchemy>=1.0.14 +sqlalchemy>=1.2.18 celery>=5.0.5 diff --git a/examples/stamping/config.py b/examples/stamping/config.py new file mode 100644 index 00000000000..e3d8869ad9c --- /dev/null +++ b/examples/stamping/config.py @@ -0,0 +1,7 @@ +from celery import Celery + +app = Celery( + 'myapp', + broker='redis://', + backend='redis://', +) diff --git a/examples/stamping/myapp.py b/examples/stamping/myapp.py new file mode 100644 index 00000000000..92e68b2cb45 --- /dev/null +++ b/examples/stamping/myapp.py @@ -0,0 +1,52 @@ +"""myapp.py + +This is a simple example of how to use the stamping feature. +It uses a custom stamping visitor to stamp a workflow with a unique +monitoring id stamp (per task), and a different visitor to stamp the last +task in the workflow. The last task is stamped with a consistent stamp, which +is used to revoke the task by its stamped header using two different approaches: +1. Run the workflow, then revoke the last task by its stamped header. +2. Revoke the last task by its stamped header before running the workflow. + +Usage:: + + # The worker service reacts to messages by executing tasks. + (window1)$ celery -A myapp worker -l INFO + + # The shell service is used to run the example. + (window2)$ celery -A myapp shell + + # Use (copy) the content of shell.py to run the workflow via the + # shell service. + + # Use one of two demo runs via the shell service: + # 1) run_then_revoke(): Run the workflow and revoke the last task + # by its stamped header during its run. + # 2) revoke_then_run(): Revoke the last task by its stamped header + # before its run, then run the workflow. + # + # See worker logs for output as defined in task_received_handler().
+""" +import json + +# Import tasks in worker context +import tasks +from config import app + +from celery.signals import task_received + + +@task_received.connect +def task_received_handler( + sender=None, + request=None, + signal=None, + **kwargs +): + print(f'In {signal.name} for: {repr(request)}') + print(f'Found stamps: {request.stamped_headers}') + print(json.dumps(request.stamps, indent=4, sort_keys=True)) + + +if __name__ == '__main__': + app.start() diff --git a/examples/stamping/shell.py b/examples/stamping/shell.py new file mode 100644 index 00000000000..3d2b48bb1a3 --- /dev/null +++ b/examples/stamping/shell.py @@ -0,0 +1,75 @@ +from time import sleep + +from tasks import identity, mul, wait_for_revoke, xsum +from visitors import MonitoringIdStampingVisitor + +from celery.canvas import Signature, chain, chord, group +from celery.result import AsyncResult + + +def create_canvas(n: int) -> Signature: + """Creates a canvas to calculate: n * sum(1..n) * 10 + For example, if n = 3, the result is 3 * (1 + 2 + 3) * 10 = 180 + """ + canvas = chain( + group(identity.s(i) for i in range(1, n+1)) | xsum.s(), + chord(group(mul.s(10) for _ in range(1, n+1)), xsum.s()), + ) + + return canvas + + +def revoke_by_headers(result: AsyncResult, terminate: bool) -> None: + """Revokes the last task in the workflow by its stamped header + + Arguments: + result (AsyncResult): Can be either a frozen or a running result + terminate (bool): If True, the revoked task will be terminated + """ + result.revoke_by_stamped_headers({'mystamp': 'I am a stamp!'}, terminate=terminate) + + +def prepare_workflow() -> Signature: + """Creates a canvas that waits "n * sum(1..n) * 10" in seconds, + with n = 3. + + The canvas itself is stamped with a unique monitoring id stamp per task. + The waiting task is stamped with different consistent stamp, which is used + to revoke the task by its stamped header. + """ + canvas = create_canvas(n=3) + canvas = canvas | wait_for_revoke.s() + canvas.stamp(MonitoringIdStampingVisitor()) + return canvas + + +def run_then_revoke(): + """Runs the workflow and lets the waiting task run for a while. + Then, the waiting task is revoked by its stamped header. + + The expected outcome is that the canvas will be calculated to the end, + but the waiting task will be revoked and terminated *during its run*. + + See worker logs for more details. + """ + canvas = prepare_workflow() + result = canvas.delay() + print('Wait 5 seconds, then revoke the last task by its stamped header: "mystamp": "I am a stamp!"') + sleep(5) + print('Revoking the last task...') + revoke_by_headers(result, terminate=True) + + +def revoke_then_run(): + """Revokes the waiting task by its stamped header before it runs. + Then, run the workflow, which will not run the waiting task that was revoked. + + The expected outcome is that the canvas will be calculated to the end, + but the waiting task will not run at all. + + See worker logs for more details. 
+ """ + canvas = prepare_workflow() + result = canvas.freeze() + revoke_by_headers(result, terminate=False) + result = canvas.delay() diff --git a/examples/stamping/tasks.py b/examples/stamping/tasks.py new file mode 100644 index 00000000000..0cb3e113809 --- /dev/null +++ b/examples/stamping/tasks.py @@ -0,0 +1,48 @@ +from time import sleep + +from config import app + +from celery import Task +from examples.stamping.visitors import MyStampingVisitor + + +class MyTask(Task): + """Custom task for stamping on replace""" + + def on_replace(self, sig): + sig.stamp(MyStampingVisitor()) + return super().on_replace(sig) + + +@app.task +def identity(x): + """Identity function""" + return x + + +@app.task +def mul(x: int, y: int) -> int: + """Multiply two numbers""" + return x * y + + +@app.task +def xsum(numbers: list) -> int: + """Sum a list of numbers""" + return sum(numbers) + + +@app.task +def waitfor(seconds: int) -> None: + """Wait for "seconds" seconds, ticking every second.""" + print(f'Waiting for {seconds} seconds...') + for i in range(seconds): + sleep(1) + print(f'{i+1} seconds passed') + + +@app.task(bind=True, base=MyTask) +def wait_for_revoke(self: MyTask, seconds: int) -> None: + """Replace this task with a new task that waits for "seconds" seconds.""" + # This will stamp waitfor with MyStampingVisitor + self.replace(waitfor.s(seconds)) diff --git a/examples/stamping/visitors.py b/examples/stamping/visitors.py new file mode 100644 index 00000000000..0b7e462014f --- /dev/null +++ b/examples/stamping/visitors.py @@ -0,0 +1,14 @@ +from uuid import uuid4 + +from celery.canvas import StampingVisitor + + +class MyStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'mystamp': 'I am a stamp!'} + + +class MonitoringIdStampingVisitor(StampingVisitor): + + def on_signature(self, sig, **headers) -> dict: + return {'monitoring_id': str(uuid4())} diff --git a/requirements/default.txt b/requirements/default.txt index 34f4c77b685..f159c7bce7f 100644 --- a/requirements/default.txt +++ b/requirements/default.txt @@ -1,5 +1,5 @@ pytz>=2021.3 -billiard>=4.0.2,<5.0 +billiard>=4.1.0,<5.0 kombu>=5.3.0b2,<6.0 vine>=5.0.0,<6.0 click>=8.1.2,<9.0 diff --git a/requirements/dev.txt b/requirements/dev.txt index fbc54e32a4e..b6425608a53 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -3,4 +3,4 @@ git+https://github.com/celery/py-amqp.git git+https://github.com/celery/kombu.git git+https://github.com/celery/billiard.git vine>=5.0.0 -isort~=5.10.1 +isort==5.11.3 diff --git a/requirements/docs.txt b/requirements/docs.txt index cdb836b29cd..d4704e0364e 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,7 +1,7 @@ sphinx_celery~=2.0.0 Sphinx>=3.0.0 sphinx-testing~=1.0.1 -sphinx-click==4.3.0 +sphinx-click==4.4.0 -r extras/sqlalchemy.txt -r test.txt -r deps/mock.txt diff --git a/requirements/extras/auth.txt b/requirements/extras/auth.txt index 388c40441b4..2a81f1cb11e 100644 --- a/requirements/extras/auth.txt +++ b/requirements/extras/auth.txt @@ -1 +1 @@ -cryptography==38.0.3 +cryptography==38.0.4 diff --git a/requirements/extras/sqlalchemy.txt b/requirements/extras/sqlalchemy.txt index 0f2e8f033eb..8e2b106495c 100644 --- a/requirements/extras/sqlalchemy.txt +++ b/requirements/extras/sqlalchemy.txt @@ -1 +1 @@ -sqlalchemy~=1.4.34 +sqlalchemy==1.4.45 diff --git a/requirements/test.txt b/requirements/test.txt index 9e6362c6ab1..cb4b7bf0d4c 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -7,7 +7,7 @@ pytest-order==1.0.1 
boto3>=1.9.178 moto>=2.2.6 # typing extensions -mypy==0.982; platform_python_implementation=="CPython" +mypy==0.991; platform_python_implementation=="CPython" pre-commit==2.20.0 -r extras/yaml.txt -r extras/msgpack.txt diff --git a/t/integration/tasks.py b/t/integration/tasks.py index 64f9512f4b6..dac9455c38e 100644 --- a/t/integration/tasks.py +++ b/t/integration/tasks.py @@ -2,6 +2,7 @@ from time import sleep from celery import Signature, Task, chain, chord, group, shared_task +from celery.canvas import StampingVisitor, signature from celery.exceptions import SoftTimeLimitExceeded from celery.utils.log import get_task_logger @@ -25,6 +26,12 @@ def add(x, y, z=None): return x + y +@shared_task +def mul(x: int, y: int) -> int: + """Multiply two numbers""" + return x * y + + @shared_task def write_to_file_and_return_int(file_name, i): with open(file_name, mode='a', buffering=1) as file_handle: @@ -421,3 +428,30 @@ def errback_old_style(request_id): def errback_new_style(request, exc, tb): redis_count(request.id) return request.id + + +class StampOnReplace(StampingVisitor): + stamp = {'StampOnReplace': 'This is the replaced task'} + + def on_signature(self, sig, **headers) -> dict: + return self.stamp + + +class StampedTaskOnReplace(Task): + """Custom task for stamping on replace""" + + def on_replace(self, sig): + sig.stamp(StampOnReplace()) + return super().on_replace(sig) + + +@shared_task +def replaced_with_me(): + return True + + +@shared_task(bind=True, base=StampedTaskOnReplace) +def replace_with_stamped_task(self: StampedTaskOnReplace, replace_with=None): + if replace_with is None: + replace_with = replaced_with_me.s() + self.replace(signature(replace_with)) diff --git a/t/integration/test_canvas.py b/t/integration/test_canvas.py index 8e805db49b7..47150bfb79e 100644 --- a/t/integration/test_canvas.py +++ b/t/integration/test_canvas.py @@ -6,22 +6,24 @@ from time import monotonic, sleep import pytest -import pytest_subtests # noqa: F401 +import pytest_subtests from celery import chain, chord, group, signature from celery.backends.base import BaseKeyValueStoreBackend +from celery.canvas import StampingVisitor from celery.exceptions import ImproperlyConfigured, TimeoutError from celery.result import AsyncResult, GroupResult, ResultSet -from celery.signals import before_task_publish +from celery.signals import before_task_publish, task_received from . 
import tasks from .conftest import TEST_BACKEND, get_active_redis_channels, get_redis_connection -from .tasks import (ExpectedException, add, add_chord_to_chord, add_replaced, add_to_all, add_to_all_to_chord, - build_chain_inside_task, collect_ids, delayed_sum, delayed_sum_with_soft_guard, - errback_new_style, errback_old_style, fail, fail_replaced, identity, ids, print_unicode, - raise_error, redis_count, redis_echo, redis_echo_group_id, replace_with_chain, - replace_with_chain_which_raises, replace_with_empty_chain, retry_once, return_exception, - return_priority, second_order_replace1, tsum, write_to_file_and_return_int, xsum) +from .tasks import (ExpectedException, StampOnReplace, add, add_chord_to_chord, add_replaced, add_to_all, + add_to_all_to_chord, build_chain_inside_task, collect_ids, delayed_sum, + delayed_sum_with_soft_guard, errback_new_style, errback_old_style, fail, fail_replaced, identity, + ids, mul, print_unicode, raise_error, redis_count, redis_echo, redis_echo_group_id, + replace_with_chain, replace_with_chain_which_raises, replace_with_empty_chain, + replace_with_stamped_task, retry_once, return_exception, return_priority, second_order_replace1, + tsum, write_to_file_and_return_int, xsum) RETRYABLE_EXCEPTIONS = (OSError, ConnectionError, TimeoutError) @@ -504,6 +506,21 @@ def test_chain_of_a_chord_and_three_tasks_and_a_group(self, manager): res = c() assert res.get(timeout=TIMEOUT) == [8, 8] + def test_stamping_example_canvas(self, manager): + """Test the stamping example canvas from the examples directory""" + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + c = chain( + group(identity.s(i) for i in range(1, 4)) | xsum.s(), + chord(group(mul.s(10) for _ in range(1, 4)), xsum.s()), + ) + + res = c() + assert res.get(timeout=TIMEOUT) == 180 + @pytest.mark.xfail(raises=TimeoutError, reason="Task is timeout") def test_nested_chain_group_lone(self, manager): """ @@ -601,6 +618,29 @@ def test_chain_with_cb_replaced_with_chain_with_cb(self, manager): assert res.get(timeout=TIMEOUT) == 'Hello world' await_redis_echo({link_msg, 'Hello world'}) + def test_chain_flattening_keep_links_of_inner_chain(self, manager): + if not manager.app.conf.result_backend.startswith('redis'): + raise pytest.skip('Requires redis result backend.') + + redis_connection = get_redis_connection() + + link_b_msg = 'link_b called' + link_b_key = 'echo_link_b' + link_b_sig = redis_echo.si(link_b_msg, redis_key=link_b_key) + + def link_chain(sig): + sig.link(link_b_sig) + sig.link_error(identity.s('link_ab')) + return sig + + inner_chain = link_chain(chain(identity.s('a'), add.s('b'))) + flat_chain = chain(inner_chain, add.s('c')) + redis_connection.delete(link_b_key) + res = flat_chain.delay() + + assert res.get(timeout=TIMEOUT) == 'abc' + await_redis_echo((link_b_msg,), redis_key=link_b_key) + def test_chain_with_eb_replaced_with_chain_with_eb( self, manager, subtests ): @@ -858,6 +898,154 @@ def before_task_publish_handler(sender=None, body=None, exchange=None, routing_k redis_connection = get_redis_connection() redis_connection.delete(redis_key) + def test_chaining_upgraded_chords_pure_groups(self, manager, subtests): + """ This test is built to reproduce the github issue https://github.com/celery/celery/issues/5958 + + The issue describes a canvas where a chain of groups are executed multiple times instead of once. + This test is built to reproduce the issue and to verify that the issue is fixed. 
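+        The canvas under test chains plain groups only, letting the chain itself perform the chord upgrade.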
+ """ + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + if not manager.app.conf.result_backend.startswith('redis'): + raise pytest.skip('Requires redis result backend.') + + redis_connection = get_redis_connection() + redis_key = 'echo_chamber' + + c = chain( + # letting the chain upgrade the chord, reproduces the issue in _chord.__or__ + group( + redis_echo.si('1', redis_key=redis_key), + redis_echo.si('2', redis_key=redis_key), + redis_echo.si('3', redis_key=redis_key), + ), + group( + redis_echo.si('4', redis_key=redis_key), + redis_echo.si('5', redis_key=redis_key), + redis_echo.si('6', redis_key=redis_key), + ), + group( + redis_echo.si('7', redis_key=redis_key), + ), + group( + redis_echo.si('8', redis_key=redis_key), + ), + redis_echo.si('9', redis_key=redis_key), + redis_echo.si('Done', redis_key='Done'), + ) + + with subtests.test(msg='Run the chain and wait for completion'): + redis_connection.delete(redis_key, 'Done') + c.delay().get(timeout=TIMEOUT) + await_redis_list_message_length(1, redis_key='Done', timeout=10) + + with subtests.test(msg='All tasks are executed once'): + actual = [sig.decode('utf-8') for sig in redis_connection.lrange(redis_key, 0, -1)] + expected = [str(i) for i in range(1, 10)] + with subtests.test(msg='All tasks are executed once'): + assert sorted(actual) == sorted(expected) + + # Cleanup + redis_connection.delete(redis_key, 'Done') + + def test_chaining_upgraded_chords_starting_with_chord(self, manager, subtests): + """ This test is built to reproduce the github issue https://github.com/celery/celery/issues/5958 + + The issue describes a canvas where a chain of groups are executed multiple times instead of once. + This test is built to reproduce the issue and to verify that the issue is fixed. 
+ """ + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + if not manager.app.conf.result_backend.startswith('redis'): + raise pytest.skip('Requires redis result backend.') + + redis_connection = get_redis_connection() + redis_key = 'echo_chamber' + + c = chain( + # by manually upgrading the chord to a group, we can reproduce the issue in _chain.__or__ + chord(group([redis_echo.si('1', redis_key=redis_key), + redis_echo.si('2', redis_key=redis_key), + redis_echo.si('3', redis_key=redis_key)]), + group([redis_echo.si('4', redis_key=redis_key), + redis_echo.si('5', redis_key=redis_key), + redis_echo.si('6', redis_key=redis_key)])), + group( + redis_echo.si('7', redis_key=redis_key), + ), + group( + redis_echo.si('8', redis_key=redis_key), + ), + redis_echo.si('9', redis_key=redis_key), + redis_echo.si('Done', redis_key='Done'), + ) + + with subtests.test(msg='Run the chain and wait for completion'): + redis_connection.delete(redis_key, 'Done') + c.delay().get(timeout=TIMEOUT) + await_redis_list_message_length(1, redis_key='Done', timeout=10) + + with subtests.test(msg='All tasks are executed once'): + actual = [sig.decode('utf-8') for sig in redis_connection.lrange(redis_key, 0, -1)] + expected = [str(i) for i in range(1, 10)] + with subtests.test(msg='All tasks are executed once'): + assert sorted(actual) == sorted(expected) + + # Cleanup + redis_connection.delete(redis_key, 'Done') + + def test_chaining_upgraded_chords_mixed_canvas(self, manager, subtests): + """ This test is built to reproduce the github issue https://github.com/celery/celery/issues/5958 + + The issue describes a canvas where a chain of groups are executed multiple times instead of once. + This test is built to reproduce the issue and to verify that the issue is fixed. 
+ """ + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + if not manager.app.conf.result_backend.startswith('redis'): + raise pytest.skip('Requires redis result backend.') + + redis_connection = get_redis_connection() + redis_key = 'echo_chamber' + + c = chain( + chord(group([redis_echo.si('1', redis_key=redis_key), + redis_echo.si('2', redis_key=redis_key), + redis_echo.si('3', redis_key=redis_key)]), + group([redis_echo.si('4', redis_key=redis_key), + redis_echo.si('5', redis_key=redis_key), + redis_echo.si('6', redis_key=redis_key)])), + redis_echo.si('7', redis_key=redis_key), + group( + redis_echo.si('8', redis_key=redis_key), + ), + redis_echo.si('9', redis_key=redis_key), + redis_echo.si('Done', redis_key='Done'), + ) + + with subtests.test(msg='Run the chain and wait for completion'): + redis_connection.delete(redis_key, 'Done') + c.delay().get(timeout=TIMEOUT) + await_redis_list_message_length(1, redis_key='Done', timeout=10) + + with subtests.test(msg='All tasks are executed once'): + actual = [sig.decode('utf-8') for sig in redis_connection.lrange(redis_key, 0, -1)] + expected = [str(i) for i in range(1, 10)] + with subtests.test(msg='All tasks are executed once'): + assert sorted(actual) == sorted(expected) + + # Cleanup + redis_connection.delete(redis_key, 'Done') + class test_result_set: @@ -2953,3 +3141,209 @@ def test_rebuild_nested_chord_chord(self, manager): tasks.rebuild_signature.s() ) sig.delay().get(timeout=TIMEOUT) + + +class test_stamping_visitor: + def test_stamp_value_type_defined_by_visitor(self, manager, subtests): + """ Test that the visitor can define the type of the stamped value """ + + @before_task_publish.connect + def before_task_publish_handler(sender=None, body=None, exchange=None, routing_key=None, headers=None, + properties=None, declare=None, retry_policy=None, **kwargs): + nonlocal task_headers + task_headers = headers.copy() + + with subtests.test(msg='Test stamping a single value'): + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': 42} + + stamped_task = add.si(1, 1) + stamped_task.stamp(visitor=CustomStampingVisitor()) + result = stamped_task.freeze() + task_headers = None + stamped_task.apply_async() + assert task_headers is not None + assert result.get() == 2 + assert 'stamps' in task_headers + assert 'stamp' in task_headers['stamps'] + assert not isinstance(task_headers['stamps']['stamp'], list) + + with subtests.test(msg='Test stamping a list of values'): + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': [4, 2]} + + stamped_task = add.si(1, 1) + stamped_task.stamp(visitor=CustomStampingVisitor()) + result = stamped_task.freeze() + task_headers = None + stamped_task.apply_async() + assert task_headers is not None + assert result.get() == 2 + assert 'stamps' in task_headers + assert 'stamp' in task_headers['stamps'] + assert isinstance(task_headers['stamps']['stamp'], list) + + def test_properties_not_affected_from_stamping(self, manager, subtests): + """ Test that the task properties are not dirty with stamping visitor entries """ + + @before_task_publish.connect + def before_task_publish_handler(sender=None, body=None, exchange=None, routing_key=None, headers=None, + properties=None, declare=None, retry_policy=None, **kwargs): + nonlocal task_headers + nonlocal task_properties + task_headers = headers.copy() + task_properties = 
properties.copy() + + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': 42} + + stamped_task = add.si(1, 1) + stamped_task.stamp(visitor=CustomStampingVisitor()) + result = stamped_task.freeze() + task_headers = None + task_properties = None + stamped_task.apply_async() + assert task_properties is not None + assert result.get() == 2 + assert 'stamped_headers' in task_headers + stamped_headers = task_headers['stamped_headers'] + + with subtests.test(msg='Test that the task properties are not dirty with stamping visitor entries'): + assert 'stamped_headers' not in task_properties, 'stamped_headers key should not be in task properties' + for stamp in stamped_headers: + assert stamp not in task_properties, f'The stamp "{stamp}" should not be in the task properties' + + def test_task_received_has_access_to_stamps(self, manager): + """ Make sure that the request has the stamps using the task_received signal """ + + assertion_result = False + + @task_received.connect + def task_received_handler( + sender=None, + request=None, + signal=None, + **kwargs + ): + nonlocal assertion_result + assertion_result = all([ + stamped_header in request.stamps + for stamped_header in request.stamped_headers + ]) + + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': 42} + + stamped_task = add.si(1, 1) + stamped_task.stamp(visitor=CustomStampingVisitor()) + stamped_task.apply_async().get() + assert assertion_result + + def test_all_tasks_of_canvas_are_stamped(self, manager, subtests): + """ Test that complex canvas are stamped correctly """ + try: + manager.app.backend.ensure_chords_allowed() + except NotImplementedError as e: + raise pytest.skip(e.args[0]) + + @task_received.connect + def task_received_handler(**kwargs): + request = kwargs['request'] + nonlocal assertion_result + + assertion_result = all([ + assertion_result, + all([stamped_header in request.stamps for stamped_header in request.stamped_headers]), + request.stamps['stamp'] == 42 + ]) + + # Using a list because pytest.mark.parametrize does not play well + canvas = [ + add.s(1, 1), + group(add.s(1, 1), add.s(2, 2)), + chain(add.s(1, 1), add.s(2, 2)), + chord([add.s(1, 1), add.s(2, 2)], xsum.s()), + chain(group(add.s(0, 0)), add.s(-1)), + add.s(1, 1) | add.s(10), + group(add.s(1, 1) | add.s(10), add.s(2, 2) | add.s(20)), + chain(add.s(1, 1) | add.s(10), add.s(2) | add.s(20)), + chord([add.s(1, 1) | add.s(10), add.s(2, 2) | add.s(20)], xsum.s()), + chain(chain(add.s(1, 1) | add.s(10), add.s(2) | add.s(20)), add.s(3) | add.s(30)), + chord(group(chain(add.s(1, 1), add.s(2)), chord([add.s(3, 3), add.s(4, 4)], xsum.s())), xsum.s()), + ] + + for sig in canvas: + with subtests.test(msg='Assert all tasks are stamped'): + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': 42} + + stamped_task = sig + stamped_task.stamp(visitor=CustomStampingVisitor()) + assertion_result = True + stamped_task.apply_async().get() + assert assertion_result + + def test_replace_merge_stamps(self, manager): + """ Test that replacing a task keeps the previous and new stamps """ + + @task_received.connect + def task_received_handler(**kwargs): + request = kwargs['request'] + nonlocal assertion_result + expected_stamp_key = list(StampOnReplace.stamp.keys())[0] + expected_stamp_value = list(StampOnReplace.stamp.values())[0] + + assertion_result = all([ + assertion_result, + all([stamped_header 
in request.stamps for stamped_header in request.stamped_headers]), + request.stamps['stamp'] == 42, + request.stamps[expected_stamp_key] == expected_stamp_value + if 'replaced_with_me' in request.task_name else True + ]) + + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, sig, **headers) -> dict: + return {'stamp': 42} + + stamped_task = replace_with_stamped_task.s() + stamped_task.stamp(visitor=CustomStampingVisitor()) + assertion_result = False + stamped_task.delay() + assertion_result = True + sleep(1) + # stamped_task needs to be stamped with CustomStampingVisitor + # and the replaced task with both CustomStampingVisitor and StampOnReplace + assert assertion_result, 'All of the tasks should have been stamped' + + def test_replace_group_merge_stamps(self, manager): + """ Test that replacing a group signature keeps the previous and new group stamps """ + + x = 5 + y = 6 + + @task_received.connect + def task_received_handler(**kwargs): + request = kwargs['request'] + nonlocal assertion_result + nonlocal gid1 + + assertion_result = all([ + assertion_result, + request.stamps['groups'][0] == gid1, + len(request.stamps['groups']) == 2 + if any([request.args == [10, x], request.args == [10, y]]) else True + ]) + + sig = add.s(3, 3) | add.s(4) | group(add.s(x), add.s(y)) + sig = group(add.s(1, 1), add.s(2, 2), replace_with_stamped_task.s(replace_with=sig)) + assertion_result = False + sig.delay() + assertion_result = True + gid1 = sig.options['task_id'] + sleep(1) + assert assertion_result, 'Group stamping is corrupted' diff --git a/t/integration/test_tasks.py b/t/integration/test_tasks.py index f681da01b61..5eea4d88e9e 100644 --- a/t/integration/test_tasks.py +++ b/t/integration/test_tasks.py @@ -200,6 +200,13 @@ def test_revoked(self, manager): def test_revoked_by_headers_simple_canvas(self, manager): """Testing revoking of task using a stamped header""" + # Try to purge the queue before we start + # to attempt to avoid interference from other tests + while True: + count = manager.app.control.purge() + if count == 0: + break + target_monitoring_id = uuid4().hex class MonitoringIdStampingVisitor(StampingVisitor): @@ -227,11 +234,13 @@ def on_signature(self, sig, **headers) -> dict: assert result.successful() is True worker_state.revoked_headers.clear() - # This test leaves the environment dirty, - # so we let it run last in the suite to avoid - # affecting other tests until we can fix it. 
- @pytest.mark.order("last") - @flaky + # Try to purge the queue after we're done + # to attempt to avoid interference to other tests + while True: + count = manager.app.control.purge() + if count == 0: + break + def test_revoked_by_headers_complex_canvas(self, manager, subtests): """Testing revoking of task using a stamped header""" try: @@ -285,6 +294,13 @@ def on_signature(self, sig, **headers) -> dict: assert result.successful() is False worker_state.revoked_headers.clear() + # Try to purge the queue after we're done + # to attempt to avoid interference to other tests + while True: + count = manager.app.control.purge() + if count == 0: + break + @flaky def test_wrong_arguments(self, manager): """Tests that proper exceptions are raised when task is called with wrong arguments.""" diff --git a/t/unit/app/test_amqp.py b/t/unit/app/test_amqp.py index 1010c4c64ce..070002d43f4 100644 --- a/t/unit/app/test_amqp.py +++ b/t/unit/app/test_amqp.py @@ -206,7 +206,7 @@ def test_as_task_message_without_utc(self): class test_AMQP_Base: - def setup(self): + def setup_method(self): self.simple_message = self.app.amqp.as_task_v2( uuid(), 'foo', create_sent_event=True, ) diff --git a/t/unit/app/test_annotations.py b/t/unit/app/test_annotations.py index e262e23ce84..7b13d37ef6a 100644 --- a/t/unit/app/test_annotations.py +++ b/t/unit/app/test_annotations.py @@ -8,7 +8,7 @@ class MyAnnotation: class AnnotationCase: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): return x + y diff --git a/t/unit/app/test_app.py b/t/unit/app/test_app.py index 844934b71b1..9d504f9fcc4 100644 --- a/t/unit/app/test_app.py +++ b/t/unit/app/test_app.py @@ -71,7 +71,7 @@ def test_task_join_will_block(self, patching): class test_App: - def setup(self): + def setup_method(self): self.app.add_defaults(deepcopy(self.CELERY_TEST_CONFIG)) def test_now(self): diff --git a/t/unit/app/test_beat.py b/t/unit/app/test_beat.py index 445aa28ed86..84f36d04f86 100644 --- a/t/unit/app/test_beat.py +++ b/t/unit/app/test_beat.py @@ -99,9 +99,9 @@ def test_lt(self): e1 = self.create_entry(schedule=timedelta(seconds=10)) e2 = self.create_entry(schedule=timedelta(seconds=2)) # order doesn't matter, see comment in __lt__ - res1 = e1 < e2 # noqa + res1 = e1 < e2 try: - res2 = e1 < object() # noqa + res2 = e1 < object() except TypeError: pass @@ -696,16 +696,19 @@ def now_func(): 'first_missed', 'first_missed', last_run_at=now_func() - timedelta(minutes=2), total_run_count=10, + app=self.app, schedule=app_schedule['first_missed']['schedule']), 'second_missed': beat.ScheduleEntry( 'second_missed', 'second_missed', last_run_at=now_func() - timedelta(minutes=2), total_run_count=10, + app=self.app, schedule=app_schedule['second_missed']['schedule']), 'non_missed': beat.ScheduleEntry( 'non_missed', 'non_missed', last_run_at=now_func() - timedelta(minutes=2), total_run_count=10, + app=self.app, schedule=app_schedule['non_missed']['schedule']), } diff --git a/t/unit/app/test_builtins.py b/t/unit/app/test_builtins.py index dcbec4b201b..94ab14e9c97 100644 --- a/t/unit/app/test_builtins.py +++ b/t/unit/app/test_builtins.py @@ -10,7 +10,7 @@ class BuiltinsCase: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def xsum(x): return sum(x) @@ -34,7 +34,7 @@ def test_run(self): class test_accumulate(BuiltinsCase): - def setup(self): + def setup_method(self): self.accumulate = self.app.tasks['celery.accumulate'] def test_with_index(self): @@ -89,7 +89,7 @@ def chunks_mul(l): class test_group(BuiltinsCase): 
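+    """Tests for the built-in group task created by builtins.add_group_task."""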
- def setup(self): + def setup_method(self): self.maybe_signature = self.patching('celery.canvas.maybe_signature') self.maybe_signature.side_effect = pass1 self.app.producer_or_acquire = Mock() @@ -98,7 +98,7 @@ def setup(self): ) self.app.conf.task_always_eager = True self.task = builtins.add_group_task(self.app) - super().setup() + super().setup_method() def test_apply_async_eager(self): self.task.apply = Mock(name='apply') @@ -132,8 +132,8 @@ def test_task__disable_add_to_parent(self, current_worker_task): class test_chain(BuiltinsCase): - def setup(self): - super().setup() + def setup_method(self): + super().setup_method() self.task = builtins.add_chain_task(self.app) def test_not_implemented(self): @@ -143,9 +143,9 @@ def test_not_implemented(self): class test_chord(BuiltinsCase): - def setup(self): + def setup_method(self): self.task = builtins.add_chord_task(self.app) - super().setup() + super().setup_method() def test_apply_async(self): x = chord([self.add.s(i, i) for i in range(10)], body=self.xsum.s()) diff --git a/t/unit/app/test_control.py b/t/unit/app/test_control.py index eb6a761e837..0908491a9ee 100644 --- a/t/unit/app/test_control.py +++ b/t/unit/app/test_control.py @@ -52,7 +52,7 @@ def test_flatten_reply(self): class test_inspect: - def setup(self): + def setup_method(self): self.app.control.broadcast = Mock(name='broadcast') self.app.control.broadcast.return_value = {} self.inspect = self.app.control.inspect() @@ -207,7 +207,7 @@ def test_report(self): class test_Control_broadcast: - def setup(self): + def setup_method(self): self.app.control.mailbox = Mock(name='mailbox') def test_broadcast(self): @@ -231,7 +231,7 @@ def test_broadcast_limit(self): class test_Control: - def setup(self): + def setup_method(self): self.app.control.broadcast = Mock(name='broadcast') self.app.control.broadcast.return_value = {} diff --git a/t/unit/app/test_defaults.py b/t/unit/app/test_defaults.py index 649ca4aab7d..509718d6b86 100644 --- a/t/unit/app/test_defaults.py +++ b/t/unit/app/test_defaults.py @@ -7,10 +7,10 @@ class test_defaults: - def setup(self): + def setup_method(self): self._prev = sys.modules.pop('celery.app.defaults', None) - def teardown(self): + def teardown_method(self): if self._prev: sys.modules['celery.app.defaults'] = self._prev diff --git a/t/unit/app/test_loaders.py b/t/unit/app/test_loaders.py index 09c8a6fe775..879887ebe9e 100644 --- a/t/unit/app/test_loaders.py +++ b/t/unit/app/test_loaders.py @@ -35,7 +35,7 @@ class test_LoaderBase: 'password': 'qwerty', 'timeout': 3} - def setup(self): + def setup_method(self): self.loader = DummyLoader(app=self.app) def test_handlers_pass(self): @@ -212,7 +212,7 @@ def find_module(self, name): class test_AppLoader: - def setup(self): + def setup_method(self): self.loader = AppLoader(app=self.app) def test_on_worker_init(self): diff --git a/t/unit/app/test_log.py b/t/unit/app/test_log.py index c3a425447a3..3be3db3a70b 100644 --- a/t/unit/app/test_log.py +++ b/t/unit/app/test_log.py @@ -150,7 +150,7 @@ def setup_logger(self, *args, **kwargs): return logging.root - def setup(self): + def setup_method(self): self.get_logger = lambda n=None: get_logger(n) if n else logging.root signals.setup_logging.receivers[:] = [] self.app.log.already_setup = False @@ -312,7 +312,7 @@ def test_logging_proxy_recurse_protection(self, restore_logging): class test_task_logger(test_default_logger): - def setup(self): + def setup_method(self): logger = self.logger = get_logger('celery.task') logger.handlers = [] 
logging.root.manager.loggerDict.pop(logger.name, None) @@ -326,7 +326,7 @@ def test_task(): from celery._state import _task_stack _task_stack.push(test_task) - def teardown(self): + def teardown_method(self): from celery._state import _task_stack _task_stack.pop() diff --git a/t/unit/app/test_registry.py b/t/unit/app/test_registry.py index 577c42e8764..8bd8ae5dbcf 100644 --- a/t/unit/app/test_registry.py +++ b/t/unit/app/test_registry.py @@ -23,7 +23,7 @@ def test_unpickle_v2(self, app): class test_TaskRegistry: - def setup(self): + def setup_method(self): self.mytask = self.app.task(name='A', shared=False)(returns) self.missing_name_task = self.app.task( name=None, shared=False)(returns) diff --git a/t/unit/app/test_routes.py b/t/unit/app/test_routes.py index fbb2803b4d1..775bbf7abd9 100644 --- a/t/unit/app/test_routes.py +++ b/t/unit/app/test_routes.py @@ -27,7 +27,7 @@ def set_queues(app, **queues): class RouteCase: - def setup(self): + def setup_method(self): self.a_queue = { 'exchange': 'fooexchange', 'exchange_type': 'fanout', diff --git a/t/unit/app/test_schedules.py b/t/unit/app/test_schedules.py index 8f49b5963b0..d6f555c2cf2 100644 --- a/t/unit/app/test_schedules.py +++ b/t/unit/app/test_schedules.py @@ -25,8 +25,8 @@ def patch_crontab_nowfun(cls, retval): class test_solar: - def setup(self): - pytest.importorskip('ephem0') + def setup_method(self): + pytest.importorskip('ephem') self.s = solar('sunrise', 60, 30, app=self.app) def test_reduce(self): @@ -475,7 +475,7 @@ def test_day_after_dst_start(self): class test_crontab_is_due: - def setup(self): + def setup_method(self): self.now = self.app.now() self.next_minute = 60 - self.now.second - 1e-6 * self.now.microsecond self.every_minute = self.crontab() @@ -800,3 +800,136 @@ def test_yearly_execution_is_not_due(self): due, remaining = self.yearly.is_due(datetime(2009, 3, 12, 7, 30)) assert not due assert remaining == 4 * 24 * 60 * 60 - 3 * 60 * 60 + + def test_execution_not_due_if_task_not_run_at_last_feasible_time_outside_deadline( + self): + """If the crontab schedule was added after the task was due, don't + immediately fire the task again""" + # could have feasibly been run on 12/5 at 7:30, but wasn't. + self.app.conf.beat_cron_starting_deadline = 3600 + last_run = datetime(2022, 12, 4, 10, 30) + now = datetime(2022, 12, 5, 10, 30) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # Run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert not due + + def test_execution_not_due_if_task_not_run_at_last_feasible_time_no_deadline_set( + self): + """Same as above test except there's no deadline set, so it should be + due""" + last_run = datetime(2022, 12, 4, 10, 30) + now = datetime(2022, 12, 5, 10, 30) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # Run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert due + + def test_execution_due_if_task_not_run_at_last_feasible_time_within_deadline( + self): + # Could have feasibly been run on 12/5 at 7:30, but wasn't. We are + # still within a 1 hour deadline from the + # last feasible run, so the task should still be due. 
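+        # Note: beat_cron_starting_deadline is given in seconds, so 3600 is the 1 hour deadline referenced above.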
+ self.app.conf.beat_cron_starting_deadline = 3600 + last_run = datetime(2022, 12, 4, 10, 30) + now = datetime(2022, 12, 5, 8, 0) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert due + + def test_execution_due_if_task_not_run_at_any_feasible_time_within_deadline( + self): + # Could have feasibly been run on 12/4 at 7:30, or 12/5 at 7:30, + # but wasn't. We are still within a 1 hour + # deadline from the last feasible run (12/5), so the task should + # still be due. + self.app.conf.beat_cron_starting_deadline = 3600 + last_run = datetime(2022, 12, 3, 10, 30) + now = datetime(2022, 12, 5, 8, 0) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # Run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert due + + def test_execution_not_due_if_task_not_run_at_any_feasible_time_outside_deadline( + self): + """Verifies that remaining is still the time to the next + feasible run date even though the original feasible date + was passed over in favor of a newer one.""" + # Could have feasibly been run on 12/4 or 12/5 at 7:30, + # but wasn't. + self.app.conf.beat_cron_starting_deadline = 3600 + last_run = datetime(2022, 12, 3, 10, 30) + now = datetime(2022, 12, 5, 11, 0) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert not due + + def test_execution_not_due_if_last_run_in_future(self): + # Should not run if the last_run hasn't happened yet. 
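+        # (A future last_run could occur, for example, if the system clock was moved backwards.)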
+ last_run = datetime(2022, 12, 6, 7, 30) + now = datetime(2022, 12, 5, 10, 30) + expected_next_execution_time = datetime(2022, 12, 7, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # Run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert not due + assert remaining == expected_remaining + + def test_execution_not_due_if_last_run_at_last_feasible_time(self): + # Last feasible time is 12/5 at 7:30 + last_run = datetime(2022, 12, 5, 7, 30) + now = datetime(2022, 12, 5, 10, 30) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # Run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert not due + + def test_execution_not_due_if_last_run_past_last_feasible_time(self): + # Last feasible time is 12/5 at 7:30 + last_run = datetime(2022, 12, 5, 8, 30) + now = datetime(2022, 12, 5, 10, 30) + expected_next_execution_time = datetime(2022, 12, 6, 7, 30) + expected_remaining = ( + expected_next_execution_time - now).total_seconds() + + # Run the daily (7:30) crontab with the current date + with patch_crontab_nowfun(self.daily, now): + due, remaining = self.daily.is_due(last_run) + assert remaining == expected_remaining + assert not due diff --git a/t/unit/apps/test_multi.py b/t/unit/apps/test_multi.py index a5c4c0e6c3a..2690872292b 100644 --- a/t/unit/apps/test_multi.py +++ b/t/unit/apps/test_multi.py @@ -172,7 +172,7 @@ def test_optmerge(self): class test_Node: - def setup(self): + def setup_method(self): self.p = Mock(name='p') self.p.options = { '--executable': 'python', @@ -308,7 +308,7 @@ def test_pidfile_custom(self, mock_exists, mock_dirs): class test_Cluster: - def setup(self): + def setup_method(self): self.Popen = self.patching('celery.apps.multi.Popen') self.kill = self.patching('os.kill') self.gethostname = self.patching('celery.apps.multi.gethostname') diff --git a/t/unit/backends/test_arangodb.py b/t/unit/backends/test_arangodb.py index 4486f0b52c0..c35fb162c78 100644 --- a/t/unit/backends/test_arangodb.py +++ b/t/unit/backends/test_arangodb.py @@ -19,7 +19,7 @@ class test_ArangoDbBackend: - def setup(self): + def setup_method(self): self.backend = ArangoDbBackend(app=self.app) def test_init_no_arangodb(self): diff --git a/t/unit/backends/test_azureblockblob.py b/t/unit/backends/test_azureblockblob.py index 5329140627f..36ca91d82cb 100644 --- a/t/unit/backends/test_azureblockblob.py +++ b/t/unit/backends/test_azureblockblob.py @@ -14,7 +14,7 @@ class test_AzureBlockBlobBackend: - def setup(self): + def setup_method(self): self.url = ( "azureblockblob://" "DefaultEndpointsProtocol=protocol;" @@ -168,7 +168,7 @@ def test_base_path_conf_default(self): class test_as_uri: - def setup(self): + def setup_method(self): self.url = ( "azureblockblob://" "DefaultEndpointsProtocol=protocol;" diff --git a/t/unit/backends/test_base.py b/t/unit/backends/test_base.py index b9084522d25..d520a5d3608 100644 --- a/t/unit/backends/test_base.py +++ b/t/unit/backends/test_base.py @@ -1,10 +1,11 @@ +import copy import re from contextlib import contextmanager from unittest.mock import ANY, MagicMock, Mock, call, patch, sentinel import pytest from kombu.serialization import prepare_accept_content -from kombu.utils.encoding import ensure_bytes +from 
kombu.utils.encoding import bytes_to_str, ensure_bytes import celery from celery import chord, group, signature, states, uuid @@ -68,7 +69,7 @@ def test_create_exception_cls(self): class test_Backend_interface: - def setup(self): + def setup_method(self): self.app.conf.accept_content = ['json'] def test_accept_precedence(self): @@ -166,7 +167,7 @@ def test_get_result_meta_with_none(self): class test_BaseBackend_interface: - def setup(self): + def setup_method(self): self.b = BaseBackend(self.app) @self.app.task(shared=False) @@ -260,7 +261,7 @@ def test_unpickleable(self): class test_prepare_exception: - def setup(self): + def setup_method(self): self.b = BaseBackend(self.app) def test_unpickleable(self): @@ -358,7 +359,7 @@ def _delete_group(self, group_id): class test_BaseBackend_dict: - def setup(self): + def setup_method(self): self.b = DictBackend(app=self.app) @self.app.task(shared=False, bind=True) @@ -649,7 +650,7 @@ def test_get_children(self): class test_KeyValueStoreBackend: - def setup(self): + def setup_method(self): self.b = KVBackend(app=self.app) def test_on_chord_part_return(self): @@ -722,6 +723,22 @@ def test_strip_prefix(self): assert self.b._strip_prefix(x) == 'x1b34' assert self.b._strip_prefix('x1b34') == 'x1b34' + def test_global_keyprefix(self): + global_keyprefix = "test_global_keyprefix_" + app = copy.deepcopy(self.app) + app.conf.get('result_backend_transport_options', {}).update({"global_keyprefix": global_keyprefix}) + b = KVBackend(app=app) + tid = uuid() + assert bytes_to_str(b.get_key_for_task(tid)) == f"{global_keyprefix}_celery-task-meta-{tid}" + assert bytes_to_str(b.get_key_for_group(tid)) == f"{global_keyprefix}_celery-taskset-meta-{tid}" + assert bytes_to_str(b.get_key_for_chord(tid)) == f"{global_keyprefix}_chord-unlock-{tid}" + + def test_global_keyprefix_missing(self): + tid = uuid() + assert bytes_to_str(self.b.get_key_for_task(tid)) == f"celery-task-meta-{tid}" + assert bytes_to_str(self.b.get_key_for_group(tid)) == f"celery-taskset-meta-{tid}" + assert bytes_to_str(self.b.get_key_for_chord(tid)) == f"chord-unlock-{tid}" + def test_get_many(self): for is_dict in True, False: self.b.mget_returns_dict = is_dict @@ -1014,7 +1031,7 @@ def test_chain_with_chord_raises_error(self): class test_as_uri: - def setup(self): + def setup_method(self): self.b = BaseBackend( app=self.app, url='sch://uuuu:pwpw@hostname.dom' diff --git a/t/unit/backends/test_cache.py b/t/unit/backends/test_cache.py index 40ae4277331..a82d0bbcfb9 100644 --- a/t/unit/backends/test_cache.py +++ b/t/unit/backends/test_cache.py @@ -20,14 +20,14 @@ def __init__(self, data): class test_CacheBackend: - def setup(self): + def setup_method(self): self.app.conf.result_serializer = 'pickle' self.tb = CacheBackend(backend='memory://', app=self.app) self.tid = uuid() self.old_get_best_memcached = backends['memcache'] backends['memcache'] = lambda: (DummyClient, ensure_bytes) - def teardown(self): + def teardown_method(self): backends['memcache'] = self.old_get_best_memcached def test_no_backend(self): @@ -143,7 +143,7 @@ def test_as_uri_multiple_servers(self): assert b.as_uri() == backend def test_regression_worker_startup_info(self): - pytest.importorskip('memcached') + pytest.importorskip('memcache') self.app.conf.result_backend = ( 'cache+memcached://127.0.0.1:11211;127.0.0.2:11211;127.0.0.3/' ) diff --git a/t/unit/backends/test_cassandra.py b/t/unit/backends/test_cassandra.py index 75d8818bcd1..9bf8a480f3d 100644 --- a/t/unit/backends/test_cassandra.py +++ 
b/t/unit/backends/test_cassandra.py @@ -18,7 +18,7 @@ class test_CassandraBackend: - def setup(self): + def setup_method(self): self.app.conf.update( cassandra_servers=['example.com'], cassandra_keyspace='celery', diff --git a/t/unit/backends/test_consul.py b/t/unit/backends/test_consul.py index 61fb5d41afd..cec77360490 100644 --- a/t/unit/backends/test_consul.py +++ b/t/unit/backends/test_consul.py @@ -9,7 +9,7 @@ class test_ConsulBackend: - def setup(self): + def setup_method(self): self.backend = ConsulBackend( app=self.app, url='consul://localhost:800') diff --git a/t/unit/backends/test_cosmosdbsql.py b/t/unit/backends/test_cosmosdbsql.py index 3ee85df43dc..bfd0d0d1e1f 100644 --- a/t/unit/backends/test_cosmosdbsql.py +++ b/t/unit/backends/test_cosmosdbsql.py @@ -13,7 +13,7 @@ class test_DocumentDBBackend: - def setup(self): + def setup_method(self): self.url = "cosmosdbsql://:key@endpoint" self.backend = CosmosDBSQLBackend(app=self.app, url=self.url) diff --git a/t/unit/backends/test_couchbase.py b/t/unit/backends/test_couchbase.py index 297735a38ba..b720b2525c5 100644 --- a/t/unit/backends/test_couchbase.py +++ b/t/unit/backends/test_couchbase.py @@ -22,7 +22,7 @@ class test_CouchbaseBackend: - def setup(self): + def setup_method(self): self.backend = CouchbaseBackend(app=self.app) def test_init_no_couchbase(self): diff --git a/t/unit/backends/test_couchdb.py b/t/unit/backends/test_couchdb.py index 41505594f72..07497b18cec 100644 --- a/t/unit/backends/test_couchdb.py +++ b/t/unit/backends/test_couchdb.py @@ -20,7 +20,7 @@ class test_CouchBackend: - def setup(self): + def setup_method(self): self.Server = self.patching('pycouchdb.Server') self.backend = CouchBackend(app=self.app) diff --git a/t/unit/backends/test_database.py b/t/unit/backends/test_database.py index c32440b2fe4..d6b03145056 100644 --- a/t/unit/backends/test_database.py +++ b/t/unit/backends/test_database.py @@ -10,10 +10,10 @@ pytest.importorskip('sqlalchemy') -from celery.backends.database import DatabaseBackend, retry, session, session_cleanup # noqa -from celery.backends.database.models import Task, TaskSet # noqa -from celery.backends.database.session import PREPARE_MODELS_MAX_RETRIES, ResultModelBase, SessionManager # noqa -from t import skip # noqa +from celery.backends.database import DatabaseBackend, retry, session, session_cleanup +from celery.backends.database.models import Task, TaskSet +from celery.backends.database.session import PREPARE_MODELS_MAX_RETRIES, ResultModelBase, SessionManager +from t import skip class SomeClass: @@ -45,7 +45,7 @@ def test_context_raises(self): @skip.if_pypy class test_DatabaseBackend: - def setup(self): + def setup_method(self): self.uri = 'sqlite:///test.db' self.app.conf.result_serializer = 'pickle' @@ -219,7 +219,7 @@ def test_TaskSet__repr__(self): @skip.if_pypy class test_DatabaseBackend_result_extended(): - def setup(self): + def setup_method(self): self.uri = 'sqlite:///test.db' self.app.conf.result_serializer = 'pickle' self.app.conf.result_extended = True diff --git a/t/unit/backends/test_dynamodb.py b/t/unit/backends/test_dynamodb.py index a27af96d6ff..0afb425e1d1 100644 --- a/t/unit/backends/test_dynamodb.py +++ b/t/unit/backends/test_dynamodb.py @@ -12,7 +12,7 @@ class test_DynamoDBBackend: - def setup(self): + def setup_method(self): self._static_timestamp = Decimal(1483425566.52) self.app.conf.result_backend = 'dynamodb://' diff --git a/t/unit/backends/test_elasticsearch.py b/t/unit/backends/test_elasticsearch.py index c39419eb52b..45f8a6fb092 100644 --- 
a/t/unit/backends/test_elasticsearch.py +++ b/t/unit/backends/test_elasticsearch.py @@ -31,7 +31,7 @@ class test_ElasticsearchBackend: - def setup(self): + def setup_method(self): self.backend = ElasticsearchBackend(app=self.app) def test_init_no_elasticsearch(self): diff --git a/t/unit/backends/test_filesystem.py b/t/unit/backends/test_filesystem.py index 4fb46683f4f..7f66a6aeae3 100644 --- a/t/unit/backends/test_filesystem.py +++ b/t/unit/backends/test_filesystem.py @@ -17,7 +17,7 @@ @t.skip.if_win32 class test_FilesystemBackend: - def setup(self): + def setup_method(self): self.directory = tempfile.mkdtemp() self.url = 'file://' + self.directory self.path = self.directory.encode('ascii') diff --git a/t/unit/backends/test_mongodb.py b/t/unit/backends/test_mongodb.py index c15ded834f1..a0bb8169ea3 100644 --- a/t/unit/backends/test_mongodb.py +++ b/t/unit/backends/test_mongodb.py @@ -77,7 +77,7 @@ class test_MongoBackend: 'hostname.dom/database?replicaSet=rs' ) - def setup(self): + def setup_method(self): self.patching('celery.backends.mongodb.MongoBackend.encode') self.patching('celery.backends.mongodb.MongoBackend.decode') self.patching('celery.backends.mongodb.Binary') diff --git a/t/unit/backends/test_redis.py b/t/unit/backends/test_redis.py index 1643c165956..dbb11db8e3e 100644 --- a/t/unit/backends/test_redis.py +++ b/t/unit/backends/test_redis.py @@ -358,7 +358,7 @@ def chord_context(self, size=1): callback.delay = Mock(name='callback.delay') yield tasks, request, callback - def setup(self): + def setup_method(self): self.Backend = self.get_backend() self.E_LOST = self.get_E_LOST() self.b = self.Backend(app=self.app) @@ -1193,7 +1193,7 @@ def get_E_LOST(self): from celery.backends.redis import E_LOST return E_LOST - def setup(self): + def setup_method(self): self.Backend = self.get_backend() self.E_LOST = self.get_E_LOST() self.b = self.Backend(app=self.app) diff --git a/t/unit/backends/test_rpc.py b/t/unit/backends/test_rpc.py index 71e573da8ff..5d37689a31d 100644 --- a/t/unit/backends/test_rpc.py +++ b/t/unit/backends/test_rpc.py @@ -23,7 +23,7 @@ def test_drain_events_before_start(self): class test_RPCBackend: - def setup(self): + def setup_method(self): self.b = RPCBackend(app=self.app) def test_oid(self): diff --git a/t/unit/bin/proj/app2.py b/t/unit/bin/proj/app2.py index 1eedbda5718..3eb4a20a0eb 100644 --- a/t/unit/bin/proj/app2.py +++ b/t/unit/bin/proj/app2.py @@ -1 +1 @@ -import celery # noqa: F401 +import celery diff --git a/t/unit/concurrency/test_concurrency.py b/t/unit/concurrency/test_concurrency.py index 1a3267bfabf..ba80aa98ec5 100644 --- a/t/unit/concurrency/test_concurrency.py +++ b/t/unit/concurrency/test_concurrency.py @@ -109,6 +109,7 @@ def test_interface_on_apply(self): def test_interface_info(self): assert BasePool(10).info == { + 'implementation': 'celery.concurrency.base:BasePool', 'max-concurrency': 10, } @@ -166,6 +167,7 @@ def test_no_concurrent_futures__returns_no_threads_pool_name(self): 'gevent', 'solo', 'processes', + 'custom', ) with patch.dict(sys.modules, {'concurrent.futures': None}): importlib.reload(concurrency) @@ -179,6 +181,7 @@ def test_concurrent_futures__returns_threads_pool_name(self): 'solo', 'processes', 'threads', + 'custom', ) with patch.dict(sys.modules, {'concurrent.futures': Mock()}): importlib.reload(concurrency) diff --git a/t/unit/concurrency/test_eventlet.py b/t/unit/concurrency/test_eventlet.py index b6a46d95ceb..30b57dae0b1 100644 --- a/t/unit/concurrency/test_eventlet.py +++ b/t/unit/concurrency/test_eventlet.py @@ -5,10 
+5,10 @@ pytest.importorskip('eventlet') -from greenlet import GreenletExit # noqa +from greenlet import GreenletExit -import t.skip # noqa -from celery.concurrency.eventlet import TaskPool, Timer, apply_target # noqa +import t.skip +from celery.concurrency.eventlet import TaskPool, Timer, apply_target eventlet_modules = ( 'eventlet', @@ -22,10 +22,10 @@ @t.skip.if_pypy class EventletCase: - def setup(self): + def setup_method(self): self.patching.modules(*eventlet_modules) - def teardown(self): + def teardown_method(self): for mod in [mod for mod in sys.modules if mod.startswith('eventlet')]: try: @@ -129,6 +129,7 @@ def test_get_info(self): x = TaskPool(10) x._pool = Mock(name='_pool') assert x._get_info() == { + 'implementation': 'celery.concurrency.eventlet:TaskPool', 'max-concurrency': 10, 'free-threads': x._pool.free(), 'running-threads': x._pool.running(), diff --git a/t/unit/concurrency/test_gevent.py b/t/unit/concurrency/test_gevent.py index 89a8398ec3b..c0b24001d90 100644 --- a/t/unit/concurrency/test_gevent.py +++ b/t/unit/concurrency/test_gevent.py @@ -26,7 +26,7 @@ def test_is_patched(self): class test_Timer: - def setup(self): + def setup_method(self): self.patching.modules(*gevent_modules) self.greenlet = self.patching('gevent.greenlet') self.GreenletExit = self.patching('gevent.greenlet.GreenletExit') @@ -57,7 +57,7 @@ def test_sched(self): class test_TaskPool: - def setup(self): + def setup_method(self): self.patching.modules(*gevent_modules) self.spawn_raw = self.patching('gevent.spawn_raw') self.Pool = self.patching('gevent.pool.Pool') diff --git a/t/unit/concurrency/test_pool.py b/t/unit/concurrency/test_pool.py index 5661f13760f..1e2d70afa83 100644 --- a/t/unit/concurrency/test_pool.py +++ b/t/unit/concurrency/test_pool.py @@ -24,7 +24,7 @@ def raise_something(i): class test_TaskPool: - def setup(self): + def setup_method(self): from celery.concurrency.prefork import TaskPool self.TaskPool = TaskPool diff --git a/t/unit/concurrency/test_prefork.py b/t/unit/concurrency/test_prefork.py index 194dec78aea..49b80c17f0c 100644 --- a/t/unit/concurrency/test_prefork.py +++ b/t/unit/concurrency/test_prefork.py @@ -194,7 +194,7 @@ class ExeMockTaskPool(mp.TaskPool): @t.skip.if_win32 class test_AsynPool: - def setup(self): + def setup_method(self): pytest.importorskip('multiprocessing') def test_gen_not_started(self): @@ -369,7 +369,7 @@ def test_register_with_event_loop__no_on_tick_dupes(self): @t.skip.if_win32 class test_ResultHandler: - def setup(self): + def setup_method(self): pytest.importorskip('multiprocessing') def test_process_result(self): diff --git a/t/unit/contrib/proj/foo.py b/t/unit/contrib/proj/foo.py index b6e3d656110..07a628b781c 100644 --- a/t/unit/contrib/proj/foo.py +++ b/t/unit/contrib/proj/foo.py @@ -1,4 +1,4 @@ -from xyzzy import plugh # noqa +from xyzzy import plugh from celery import Celery, shared_task diff --git a/t/unit/contrib/test_abortable.py b/t/unit/contrib/test_abortable.py index 9edc8435ae4..3c3d55344ff 100644 --- a/t/unit/contrib/test_abortable.py +++ b/t/unit/contrib/test_abortable.py @@ -3,7 +3,7 @@ class test_AbortableTask: - def setup(self): + def setup_method(self): @self.app.task(base=AbortableTask, shared=False) def abortable(): return True diff --git a/t/unit/contrib/test_sphinx.py b/t/unit/contrib/test_sphinx.py index a4d74e04465..0b2bad28509 100644 --- a/t/unit/contrib/test_sphinx.py +++ b/t/unit/contrib/test_sphinx.py @@ -3,7 +3,7 @@ import pytest try: - from sphinx.application import Sphinx # noqa: F401 + from sphinx.application 
import Sphinx from sphinx_testing import TestApp sphinx_installed = True except ImportError: diff --git a/t/unit/contrib/test_worker.py b/t/unit/contrib/test_worker.py index f2ccf0625bd..17cf005f175 100644 --- a/t/unit/contrib/test_worker.py +++ b/t/unit/contrib/test_worker.py @@ -2,13 +2,13 @@ # this import adds a @shared_task, which uses connect_on_app_finalize # to install the celery.ping task that the test lib uses -import celery.contrib.testing.tasks # noqa: F401 +import celery.contrib.testing.tasks from celery import Celery from celery.contrib.testing.worker import start_worker class test_worker: - def setup(self): + def setup_method(self): self.app = Celery('celerytest', backend='cache+memory://', broker='memory://',) @self.app.task diff --git a/t/unit/events/test_cursesmon.py b/t/unit/events/test_cursesmon.py index 17cce119fed..fa0816050de 100644 --- a/t/unit/events/test_cursesmon.py +++ b/t/unit/events/test_cursesmon.py @@ -11,7 +11,7 @@ def getmaxyx(self): class test_CursesDisplay: - def setup(self): + def setup_method(self): from celery.events import cursesmon self.monitor = cursesmon.CursesMonitor(object(), app=self.app) self.win = MockWindow() diff --git a/t/unit/events/test_snapshot.py b/t/unit/events/test_snapshot.py index 3dfb01846e9..c09d67d10e5 100644 --- a/t/unit/events/test_snapshot.py +++ b/t/unit/events/test_snapshot.py @@ -19,7 +19,7 @@ def call_repeatedly(self, secs, fun, *args, **kwargs): class test_Polaroid: - def setup(self): + def setup_method(self): self.state = self.app.events.State() def test_constructor(self): @@ -101,7 +101,7 @@ class MockEvents(Events): def Receiver(self, *args, **kwargs): return test_evcam.MockReceiver() - def setup(self): + def setup_method(self): self.app.events = self.MockEvents() self.app.events.app = self.app diff --git a/t/unit/security/case.py b/t/unit/security/case.py index 36f0e5e4c95..319853dbfda 100644 --- a/t/unit/security/case.py +++ b/t/unit/security/case.py @@ -3,5 +3,5 @@ class SecurityCase: - def setup(self): + def setup_method(self): pytest.importorskip('cryptography') diff --git a/t/unit/security/test_security.py b/t/unit/security/test_security.py index 0559919997e..fc9a5e69004 100644 --- a/t/unit/security/test_security.py +++ b/t/unit/security/test_security.py @@ -33,7 +33,7 @@ class test_security(SecurityCase): - def teardown(self): + def teardown_method(self): registry._disabled_content_types.clear() registry._set_default_serializer('json') try: diff --git a/t/unit/tasks/test_canvas.py b/t/unit/tasks/test_canvas.py index 493ce04d50a..1c23b4fa693 100644 --- a/t/unit/tasks/test_canvas.py +++ b/t/unit/tasks/test_canvas.py @@ -3,12 +3,13 @@ from unittest.mock import ANY, MagicMock, Mock, call, patch, sentinel import pytest -import pytest_subtests # noqa: F401 +import pytest_subtests from celery import Task from celery._state import _task_stack -from celery.canvas import (GroupStampingVisitor, Signature, StampingVisitor, _chain, _maybe_group, chain, chord, - chunks, group, maybe_signature, maybe_unroll_group, signature, xmap, xstarmap) +from celery.canvas import (GroupStampingVisitor, Signature, StampingVisitor, _chain, _maybe_group, + _merge_dictionaries, chain, chord, chunks, group, maybe_signature, maybe_unroll_group, + signature, xmap, xstarmap) from celery.exceptions import Ignore from celery.result import AsyncResult, EagerResult, GroupResult @@ -44,7 +45,7 @@ def test_when_no_len_and_no_length_hint(self): class CanvasCase: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): 
return x + y @@ -137,6 +138,20 @@ def __init__(self, *args, **kwargs): class test_Signature(CanvasCase): + @pytest.mark.usefixtures('depends_on_current_app') + def test_on_signature_gets_the_signature(self): + expected_sig = self.add.s(4, 2) + + class CustomStampingVisitor(StampingVisitor): + def on_signature(self, actual_sig, **headers) -> dict: + nonlocal expected_sig + assert actual_sig == expected_sig + return {'header': 'value'} + + sig = expected_sig.clone() + sig.stamp(CustomStampingVisitor()) + assert sig.options['header'] == 'value' + def test_double_stamping(self, subtests): """ Test manual signature stamping with two different stamps. @@ -440,7 +455,7 @@ def test_flatten_links(self): tasks[1].link(tasks[2]) assert tasks[0].flatten_links() == tasks - def test_OR(self): + def test_OR(self, subtests): x = self.add.s(2, 2) | self.mul.s(4) assert isinstance(x, _chain) y = self.add.s(4, 4) | self.div.s(2) @@ -454,6 +469,10 @@ def test_OR(self): assert isinstance(ax, _chain) assert len(ax.tasks), 3 == 'consolidates chain to chain' + with subtests.test('Test chaining with a non-signature object'): + with pytest.raises(TypeError): + assert signature('foo') | None + def test_INVERT(self): x = self.add.s(2, 2) x.apply_async = Mock() @@ -563,6 +582,32 @@ def test_keeping_link_error_on_chaining(self): assert SIG in x.options['link_error'] assert not x.tasks[0].options.get('link_error') + def test_signature_on_error_adds_error_callback(self): + sig = signature('sig').on_error(signature('on_error')) + assert sig.options['link_error'] == [signature('on_error')] + + @pytest.mark.parametrize('_id, group_id, chord, root_id, parent_id, group_index', [ + ('_id', 'group_id', 'chord', 'root_id', 'parent_id', 1), + ]) + def test_freezing_args_set_in_options(self, _id, group_id, chord, root_id, parent_id, group_index): + sig = self.add.s(1, 1) + sig.freeze( + _id=_id, + group_id=group_id, + chord=chord, + root_id=root_id, + parent_id=parent_id, + group_index=group_index, + ) + options = sig.options + + assert options['task_id'] == _id + assert options['group_id'] == group_id + assert options['chord'] == chord + assert options['root_id'] == root_id + assert options['parent_id'] == parent_id + assert options['group_index'] == group_index + class test_xmap_xstarmap(CanvasCase): @@ -753,6 +798,22 @@ def test_chord_to_group(self): ['x0y0', 'x1y1', 'foo', 'z'] ] + def test_chain_of_chord__or__group_of_single_task(self): + c = chord([signature('header')], signature('body')) + c = chain(c) + g = group(signature('t')) + new_chain = c | g # g should be chained with the body of c[0] + assert isinstance(new_chain, _chain) + assert isinstance(new_chain.tasks[0].body, _chain) + + def test_chain_of_chord_upgrade_on_chaining(self): + c = chord([signature('header')], group(signature('body'))) + c = chain(c) + t = signature('t') + new_chain = c | t # t should be chained with the body of c[0] and create a new chord + assert isinstance(new_chain, _chain) + assert isinstance(new_chain.tasks[0].body, chord) + def test_apply_options(self): class static(Signature): @@ -951,6 +1012,30 @@ def test_chain_single_child_group_result(self): mock_apply.assert_called_once_with(chain=[]) assert res is mock_apply.return_value + def test_chain_flattening_keep_links_of_inner_chain(self): + def link_chain(sig): + sig.link(signature('link_b')) + sig.link_error(signature('link_ab')) + return sig + + inner_chain = link_chain(chain(signature('a'), signature('b'))) + assert inner_chain.options['link'][0] == signature('link_b') + assert 
inner_chain.options['link_error'][0] == signature('link_ab') + assert inner_chain.tasks[0] == signature('a') + assert inner_chain.tasks[0].options == {} + assert inner_chain.tasks[1] == signature('b') + assert inner_chain.tasks[1].options == {} + + flat_chain = chain(inner_chain, signature('c')) + assert flat_chain.options == {} + assert flat_chain.tasks[0].name == 'a' + assert 'link' not in flat_chain.tasks[0].options + assert signature(flat_chain.tasks[0].options['link_error'][0]) == signature('link_ab') + assert flat_chain.tasks[1].name == 'b' + assert 'link' in flat_chain.tasks[1].options, "b is missing the link from inner_chain.options['link'][0]" + assert signature(flat_chain.tasks[1].options['link'][0]) == signature('link_b') + assert signature(flat_chain.tasks[1].options['link_error'][0]) == signature('link_ab') + class test_group(CanvasCase): def test_group_stamping_one_level(self, subtests): @@ -1278,6 +1363,10 @@ def test_repr(self): x = group([self.add.s(2, 2), self.add.s(4, 4)]) assert repr(x) + def test_repr_empty_group(self): + x = group([]) + assert repr(x) == 'group()' + def test_reverse(self): x = group([self.add.s(2, 2), self.add.s(4, 4)]) assert isinstance(signature(x), group) @@ -1661,6 +1750,19 @@ def test_apply_contains_chords_containing_empty_chord(self): # the encapsulated chains - in this case 1 for each child chord mock_set_chord_size.assert_has_calls((call(ANY, 1),) * child_count) + def test_group_prepared(self): + # Using both partial and dict based signatures + sig = group(dict(self.add.s(0)), self.add.s(0)) + _, group_id, root_id = sig._freeze_gid({}) + tasks = sig._prepared(sig.tasks, [42], group_id, root_id, self.app) + + for task, result, group_id in tasks: + assert isinstance(task, Signature) + assert task.args[0] == 42 + assert task.args[1] == 0 + assert isinstance(result, AsyncResult) + assert group_id is not None + class test_chord(CanvasCase): def test_chord_stamping_one_level(self, subtests): @@ -2317,6 +2419,38 @@ def test_flag_allow_error_cb_on_chord_header_various_header_types(self): errback = c.link_error(sig) assert errback == sig + def test_chord__or__group_of_single_task(self): + """ Test chaining a chord to a group of a single task. 
""" + c = chord([signature('header')], signature('body')) + g = group(signature('t')) + stil_chord = c | g # g should be chained with the body of c + assert isinstance(stil_chord, chord) + assert isinstance(stil_chord.body, _chain) + + def test_chord_upgrade_on_chaining(self): + """ Test that chaining a chord with a group body upgrades to a new chord """ + c = chord([signature('header')], group(signature('body'))) + t = signature('t') + stil_chord = c | t # t should be chained with the body of c and create a new chord + assert isinstance(stil_chord, chord) + assert isinstance(stil_chord.body, chord) + + @pytest.mark.parametrize('header', [ + [signature('s1'), signature('s2')], + group(signature('s1'), signature('s2')) + ]) + @pytest.mark.usefixtures('depends_on_current_app') + def test_link_error_on_chord_header(self, header): + """ Test that link_error on a chord also links the header """ + self.app.conf.task_allow_error_cb_on_chord_header = True + c = chord(header, signature('body')) + err = signature('err') + errback = c.link_error(err) + assert errback == err + for header_task in c.tasks: + assert header_task.options['link_error'] == [err] + assert c.body.options['link_error'] == [err] + class test_maybe_signature(CanvasCase): @@ -2330,3 +2464,63 @@ def test_is_dict(self): def test_when_sig(self): s = self.add.s() assert maybe_signature(s, app=self.app) is s + + +class test_merge_dictionaries(CanvasCase): + + def test_docstring_example(self): + d1 = {'dict': {'a': 1}, 'list': [1, 2], 'tuple': (1, 2)} + d2 = {'dict': {'b': 2}, 'list': [3, 4], 'set': {'a', 'b'}} + _merge_dictionaries(d1, d2) + assert d1 == { + 'dict': {'a': 1, 'b': 2}, + 'list': [1, 2, 3, 4], + 'tuple': (1, 2), + 'set': {'a', 'b'} + } + + @pytest.mark.parametrize('d1,d2,expected_result', [ + ( + {'None': None}, + {'None': None}, + {'None': [None]} + ), + ( + {'None': None}, + {'None': [None]}, + {'None': [[None]]} + ), + ( + {'None': None}, + {'None': 'Not None'}, + {'None': ['Not None']} + ), + ( + {'None': None}, + {'None': ['Not None']}, + {'None': [['Not None']]} + ), + ( + {'None': [None]}, + {'None': None}, + {'None': [None, None]} + ), + ( + {'None': [None]}, + {'None': [None]}, + {'None': [None, None]} + ), + ( + {'None': [None]}, + {'None': 'Not None'}, + {'None': [None, 'Not None']} + ), + ( + {'None': [None]}, + {'None': ['Not None']}, + {'None': [None, 'Not None']} + ), + ]) + def test_none_values(self, d1, d2, expected_result): + _merge_dictionaries(d1, d2) + assert d1 == expected_result diff --git a/t/unit/tasks/test_chord.py b/t/unit/tasks/test_chord.py index c2aad5f894f..0c3ddf19b0b 100644 --- a/t/unit/tasks/test_chord.py +++ b/t/unit/tasks/test_chord.py @@ -20,7 +20,7 @@ def __eq__(self, other): class ChordCase: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): @@ -323,7 +323,7 @@ def sumX(n): class test_add_to_chord: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): diff --git a/t/unit/tasks/test_result.py b/t/unit/tasks/test_result.py index 6b288e9c557..818409c97d9 100644 --- a/t/unit/tasks/test_result.py +++ b/t/unit/tasks/test_result.py @@ -63,7 +63,7 @@ def remove_pending_result(self, *args, **kwargs): class test_AsyncResult: - def setup(self): + def setup_method(self): self.app.conf.result_cache_max = 100 self.app.conf.result_serializer = 'pickle' self.app.conf.result_extended = True @@ -628,7 +628,7 @@ def get_many(self, *args, **kwargs): class test_GroupResult: - def setup(self): + def setup_method(self): self.size = 10 
self.ts = self.app.GroupResult( uuid(), make_mock_group(self.app, self.size), @@ -882,7 +882,7 @@ def test_result(self, app): class test_failed_AsyncResult: - def setup(self): + def setup_method(self): self.size = 11 self.app.conf.result_serializer = 'pickle' results = make_mock_group(self.app, 10) @@ -907,7 +907,7 @@ def test_failed(self): class test_pending_Group: - def setup(self): + def setup_method(self): self.ts = self.app.GroupResult( uuid(), [self.app.AsyncResult(uuid()), self.app.AsyncResult(uuid())]) @@ -932,7 +932,7 @@ def test_join_longer(self): class test_EagerResult: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def raising(x, y): raise KeyError(x, y) diff --git a/t/unit/tasks/test_tasks.py b/t/unit/tasks/test_tasks.py index 2a5f08d6c4f..a636eac73be 100644 --- a/t/unit/tasks/test_tasks.py +++ b/t/unit/tasks/test_tasks.py @@ -60,7 +60,7 @@ class TaskWithRetryButForTypeError(Task): class TasksCase: - def setup(self): + def setup_method(self): self.mytask = self.app.task(shared=False)(return_True) @self.app.task(bind=True, count=0, shared=False) diff --git a/t/unit/tasks/test_trace.py b/t/unit/tasks/test_trace.py index 60fa253dda3..e7767a979f5 100644 --- a/t/unit/tasks/test_trace.py +++ b/t/unit/tasks/test_trace.py @@ -28,7 +28,7 @@ def trace( class TraceCase: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): return x + y diff --git a/t/unit/utils/test_collections.py b/t/unit/utils/test_collections.py index aae685ebc7c..79ccc011741 100644 --- a/t/unit/utils/test_collections.py +++ b/t/unit/utils/test_collections.py @@ -52,7 +52,7 @@ def test_items(self): class test_ConfigurationView: - def setup(self): + def setup_method(self): self.view = ConfigurationView( {'changed_key': 1, 'both': 2}, [ diff --git a/t/unit/utils/test_functional.py b/t/unit/utils/test_functional.py index 57055a14a6e..9b9ec087e06 100644 --- a/t/unit/utils/test_functional.py +++ b/t/unit/utils/test_functional.py @@ -1,7 +1,7 @@ import collections import pytest -import pytest_subtests # noqa: F401 +import pytest_subtests from kombu.utils.functional import lazy from celery.utils.functional import (DummyContext, first, firstmethod, fun_accepts_kwargs, fun_takes_argument, diff --git a/t/unit/worker/test_autoscale.py b/t/unit/worker/test_autoscale.py index f6c63c57ac3..c4a2a75ed73 100644 --- a/t/unit/worker/test_autoscale.py +++ b/t/unit/worker/test_autoscale.py @@ -73,7 +73,7 @@ def test_info_without_event_loop(self): class test_Autoscaler: - def setup(self): + def setup_method(self): self.pool = MockPool(3) def test_stop(self): diff --git a/t/unit/worker/test_bootsteps.py b/t/unit/worker/test_bootsteps.py index cb1e91f77be..4a33f44da35 100644 --- a/t/unit/worker/test_bootsteps.py +++ b/t/unit/worker/test_bootsteps.py @@ -56,7 +56,7 @@ class test_Step: class Def(bootsteps.StartStopStep): name = 'test_Step.Def' - def setup(self): + def setup_method(self): self.steps = [] def test_blueprint_name(self, bp='test_blueprint_name'): @@ -162,7 +162,7 @@ class test_StartStopStep: class Def(bootsteps.StartStopStep): name = 'test_StartStopStep.Def' - def setup(self): + def setup_method(self): self.steps = [] def test_start__stop(self): diff --git a/t/unit/worker/test_components.py b/t/unit/worker/test_components.py index 14869cf6df7..739808e4311 100644 --- a/t/unit/worker/test_components.py +++ b/t/unit/worker/test_components.py @@ -22,7 +22,7 @@ def test_create__eventloop(self): class test_Hub: - def setup(self): + def setup_method(self): self.w = 
Mock(name='w') self.hub = Hub(self.w) self.w.hub = Mock(name='w.hub') diff --git a/t/unit/worker/test_consumer.py b/t/unit/worker/test_consumer.py index 7865cc3ac77..707f6db4302 100644 --- a/t/unit/worker/test_consumer.py +++ b/t/unit/worker/test_consumer.py @@ -4,6 +4,7 @@ from unittest.mock import MagicMock, Mock, call, patch import pytest +from amqp import ChannelError from billiard.exceptions import RestartFreqExceeded from celery import bootsteps @@ -41,7 +42,7 @@ def get_consumer(self, no_hub=False, **kwargs): class test_Consumer(ConsumerTestCase): - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): return x + y @@ -310,6 +311,31 @@ def test_blueprint_restart_when_state_not_in_stop_conditions(self, broker_connec c.start() c.blueprint.restart.assert_called_once() + @pytest.mark.parametrize("broker_channel_error_retry", [True, False]) + def test_blueprint_restart_for_channel_errors(self, broker_channel_error_retry): + c = self.get_consumer() + + # ensure that WorkerShutdown is not raised + c.app.conf['broker_connection_retry'] = True + c.app.conf['broker_connection_retry_on_startup'] = True + c.app.conf['broker_channel_error_retry'] = broker_channel_error_retry + c.restart_count = -1 + + # ensure that blueprint state is not in stop conditions + c.blueprint.state = bootsteps.RUN + c.blueprint.start.side_effect = ChannelError() + + # stops test from running indefinitely in the while loop + c.blueprint.restart.side_effect = self._closer(c) + + # restarted only when broker_channel_error_retry is True + if broker_channel_error_retry: + c.start() + c.blueprint.restart.assert_called_once() + else: + with pytest.raises(ChannelError): + c.start() + def test_collects_at_restart(self): c = self.get_consumer() c.connection.collect.side_effect = MemoryError() diff --git a/t/unit/worker/test_control.py b/t/unit/worker/test_control.py index 33cc521cb5c..a1761a1cb01 100644 --- a/t/unit/worker/test_control.py +++ b/t/unit/worker/test_control.py @@ -116,7 +116,7 @@ def se(*args, **kwargs): class test_ControlPanel: - def setup(self): + def setup_method(self): self.panel = self.create_panel(consumer=Consumer(self.app)) @self.app.task(name='c.unittest.mytask', rate_limit=200, shared=False) diff --git a/t/unit/worker/test_loops.py b/t/unit/worker/test_loops.py index 8a1fe63e4a0..68e84562b4c 100644 --- a/t/unit/worker/test_loops.py +++ b/t/unit/worker/test_loops.py @@ -133,7 +133,7 @@ def get_task_callback(*args, **kwargs): class test_asynloop: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): return x + y @@ -529,7 +529,7 @@ def drain_events(timeout): class test_quick_drain: - def setup(self): + def setup_method(self): self.connection = Mock(name='connection') def test_drain(self): diff --git a/t/unit/worker/test_request.py b/t/unit/worker/test_request.py index b818f2837cc..bd63561f0cc 100644 --- a/t/unit/worker/test_request.py +++ b/t/unit/worker/test_request.py @@ -26,7 +26,7 @@ class RequestCase: - def setup(self): + def setup_method(self): self.app.conf.result_serializer = 'pickle' @self.app.task(shared=False) @@ -155,7 +155,7 @@ def test_execute_jail_failure(self): self.app, uuid(), self.mytask_raising.name, {}, [4], {}, ) assert isinstance(ret, ExceptionInfo) - assert ret.exception.exc.args == (4,) + assert ret.exception.args == (4,) def test_execute_task_ignore_result(self): @self.app.task(shared=False, ignore_result=True) @@ -1173,11 +1173,11 @@ def test_group_index(self): class test_create_request_class(RequestCase): - def 
setup(self): + def setup_method(self): self.task = Mock(name='task') self.pool = Mock(name='pool') self.eventer = Mock(name='eventer') - super().setup() + super().setup_method() def create_request_cls(self, **kwargs): return create_request_cls( diff --git a/t/unit/worker/test_state.py b/t/unit/worker/test_state.py index bdff94facbf..cf67aa25957 100644 --- a/t/unit/worker/test_state.py +++ b/t/unit/worker/test_state.py @@ -45,7 +45,7 @@ class MyPersistent(state.Persistent): class test_maybe_shutdown: - def teardown(self): + def teardown_method(self): state.should_stop = None state.should_terminate = None diff --git a/t/unit/worker/test_strategy.py b/t/unit/worker/test_strategy.py index 8d7098954af..366d5c62081 100644 --- a/t/unit/worker/test_strategy.py +++ b/t/unit/worker/test_strategy.py @@ -18,7 +18,7 @@ class test_proto1_to_proto2: - def setup(self): + def setup_method(self): self.message = Mock(name='message') self.body = { 'args': (1,), @@ -58,7 +58,7 @@ def test_message(self): class test_default_strategy_proto2: - def setup(self): + def setup_method(self): @self.app.task(shared=False) def add(x, y): return x + y @@ -301,7 +301,7 @@ def failed(): class test_hybrid_to_proto2: - def setup(self): + def setup_method(self): self.message = Mock(name='message', headers={"custom": "header"}) self.body = { 'args': (1,), diff --git a/t/unit/worker/test_worker.py b/t/unit/worker/test_worker.py index 6bf2a14a1d6..cfa67440b4c 100644 --- a/t/unit/worker/test_worker.py +++ b/t/unit/worker/test_worker.py @@ -77,7 +77,7 @@ def create_task_message(self, channel, *args, **kwargs): class test_Consumer(ConsumerCase): - def setup(self): + def setup_method(self): self.buffer = FastQueue() self.timer = Timer() @@ -86,7 +86,7 @@ def foo_task(x, y, z): return x * y * z self.foo_task = foo_task - def teardown(self): + def teardown_method(self): self.timer.stop() def LoopConsumer(self, buffer=None, controller=None, timer=None, app=None, @@ -697,7 +697,7 @@ def test_reset_connection_with_no_node(self): class test_WorkController(ConsumerCase): - def setup(self): + def setup_method(self): self.worker = self.create_worker() self._logger = worker_module.logger self._comp_logger = components.logger @@ -709,7 +709,7 @@ def foo_task(x, y, z): return x * y * z self.foo_task = foo_task - def teardown(self): + def teardown_method(self): worker_module.logger = self._logger components.logger = self._comp_logger
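
The new test_blueprint_restart_for_channel_errors above exercises the broker_channel_error_retry option added to celery/app/defaults.py. Reduced to a sketch, the intended behavior looks roughly like the following; this is not Celery's actual Consumer.start loop, and STOP_CONDITIONS is a hypothetical stand-in for the blueprint's terminal states:

from amqp import ChannelError

STOP_CONDITIONS = {'close', 'terminate'}  # hypothetical stand-in

def consumer_loop(c):
    while c.blueprint.state not in STOP_CONDITIONS:
        try:
            c.blueprint.start(c)
            return  # clean exit once the blueprint finishes
        except ChannelError:
            if not c.app.conf.broker_channel_error_retry:
                raise  # default (False): channel errors stay fatal
            c.blueprint.restart(c)  # opt-in (True): recover as with connection errors

With the option disabled the ChannelError propagates, which is what the pytest.raises branch of the test asserts; with it enabled the blueprint is restarted, matching restart.assert_called_once().
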
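
Apart from the behavioral changes, the bulk of the hunks in these test modules are a mechanical rename of the nose-style setup/teardown hooks to pytest's native xunit-style setup_method/teardown_method, which pytest started warning about as part of deprecating its nose integration. A minimal standalone illustration of the target convention (class and method names here are invented for the example):

class TestExample:
    def setup_method(self):
        # runs before every test method in this class,
        # replacing the nose-style `def setup(self)` hook
        self.items = []

    def teardown_method(self):
        # runs after every test method,
        # replacing the nose-style `def teardown(self)` hook
        del self.items

    def test_append(self):
        self.items.append(1)
        assert self.items == [1]

Subclasses that extend a base case's hook, like test_create_request_class above, must update the super() call as well, from super().setup() to super().setup_method().
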