|
| 1 | +import inspect |
| 2 | +import logging |
1 | 3 | import threading |
2 | 4 |
|
3 | 5 | from celery import Celery |
| 6 | +from application.core import log_context |
4 | 7 | from application.core.settings import settings |
5 | | -from celery.signals import setup_logging, worker_process_init, worker_ready |
| 8 | +from celery.signals import ( |
| 9 | + setup_logging, |
| 10 | + task_postrun, |
| 11 | + task_prerun, |
| 12 | + worker_process_init, |
| 13 | + worker_ready, |
| 14 | +) |
6 | 15 |
|
7 | 16 |
|
8 | 17 | def make_celery(app_name=__name__): |
@@ -41,6 +50,54 @@ def _dispose_db_engine_on_fork(*args, **kwargs): |
41 | 50 | dispose_engine() |
42 | 51 |
|
43 | 52 |
|
| 53 | +# Most tasks in this repo accept ``user`` where the log context wants |
| 54 | +# ``user_id``; map task parameter names to context keys explicitly. |
| 55 | +_TASK_PARAM_TO_CTX_KEY: dict[str, str] = { |
| 56 | + "user": "user_id", |
| 57 | + "user_id": "user_id", |
| 58 | + "agent_id": "agent_id", |
| 59 | + "conversation_id": "conversation_id", |
| 60 | +} |
| 61 | + |
| 62 | +_task_log_tokens: dict[str, object] = {} |
| 63 | + |
| 64 | + |
| 65 | +@task_prerun.connect |
| 66 | +def _bind_task_log_context(task_id, task, args, kwargs, **_): |
| 67 | + # Resolve task args by parameter name — nearly every task in this repo |
| 68 | + # is called positionally, so ``kwargs.get('user')`` would bind nothing. |
| 69 | + ctx = {"activity_id": task_id} |
| 70 | + try: |
| 71 | + sig = inspect.signature(task.run) |
| 72 | + bound = sig.bind_partial(*args, **kwargs).arguments |
| 73 | + except (TypeError, ValueError): |
| 74 | + bound = dict(kwargs) |
| 75 | + for param_name, value in bound.items(): |
| 76 | + ctx_key = _TASK_PARAM_TO_CTX_KEY.get(param_name) |
| 77 | + if ctx_key and value: |
| 78 | + ctx[ctx_key] = value |
| 79 | + _task_log_tokens[task_id] = log_context.bind(**ctx) |
| 80 | + |
| 81 | + |
| 82 | +@task_postrun.connect |
| 83 | +def _unbind_task_log_context(task_id, **_): |
| 84 | + # ``task_postrun`` fires on both success and failure. Required for |
| 85 | + # Celery: unlike the Flask path, tasks aren't isolated in their own |
| 86 | + # ``copy_context().run(...)``, so a missing reset would leak the |
| 87 | + # bind onto the next task on the same worker. |
| 88 | + token = _task_log_tokens.pop(task_id, None) |
| 89 | + if token is None: |
| 90 | + return |
| 91 | + try: |
| 92 | + log_context.reset(token) |
| 93 | + except ValueError: |
| 94 | + # task_prerun and task_postrun ran on different threads (non-default |
| 95 | + # Celery pool); the token isn't valid in this context. Drop it. |
| 96 | + logging.getLogger(__name__).debug( |
| 97 | + "log_context reset skipped for task %s", task_id |
| 98 | + ) |
| 99 | + |
| 100 | + |
44 | 101 | @worker_ready.connect |
45 | 102 | def _run_version_check(*args, **kwargs): |
46 | 103 | """Kick off the anonymous version check on worker startup. |
|
0 commit comments