Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 201 additions & 3 deletions poetry.lock

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ pycron = "^3.0.0"
taskiq_dependencies = ">=1.3.1,<2"
anyio = ">=4"
packaging = ">=19"
# For opentelemetry instrumentation
opentelemetry-api = { version = "^1.38.0", optional = true }
opentelemetry-instrumentation = { version = "^0.59b0", optional = true}
opentelemetry-semantic-conventions = { version = "^0.59b0", optional = true}
# For prometheus metrics
prometheus_client = { version = "^0", optional = true }
# For ZMQBroker
Expand Down Expand Up @@ -69,10 +73,12 @@ pytest-mock = "^3.11.1"
tzlocal = "^5.0.1"
types-tzlocal = "^5.0.1.1"
types-pytz = "^2023.3.1.1"
opentelemetry-test-utils = "^0.59b0"

[tool.poetry.extras]
zmq = ["pyzmq"]
uv = ["uvloop"]
opentelemetry = ["opentelemetry-api", "opentelemetry-instrumentation", "opentelemetry-semantic-conventions"]
metrics = ["prometheus_client"]
reload = ["watchdog", "gitignore-parser"]
orjson = ["orjson"]
Expand All @@ -86,6 +92,9 @@ taskiq = "taskiq.__main__:main"
worker = "taskiq.cli.worker.cmd:WorkerCMD"
scheduler = "taskiq.cli.scheduler.cmd:SchedulerCMD"

[tool.poetry.plugins.opentelemetry_instrumentor]
taskiq = "taskiq.instrumentation:TaskiqInstrumentor"

[tool.mypy]
strict = true
ignore_missing_imports = true
Expand Down
157 changes: 157 additions & 0 deletions taskiq/instrumentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""
Instrument `taskiq`_ to trace Taskiq applications.

.. _taskiq: https://pypi.org/project/taskiq/

Usage
-----

* Run instrumented task

.. code:: python

import asyncio

from taskiq import InMemoryBroker, TaskiqEvents, TaskiqState
from taskiq.instrumentation import TaskiqInstrumentor

broker = InMemoryBroker()

@broker.task
async def add(x, y):
return x + y

async def main():
TaskiqInstrumentor().instrument()
await broker.startup()
await my_task.kiq(1, 2)
await broker.shutdown()

if __name__ == "__main__":
asyncio.run(main())

API
---
"""


import logging
from functools import partial
from typing import Any, Callable, Collection, Optional
from weakref import WeakSet as _WeakSet

from taskiq.cli.worker.args import WorkerArgs

try:
import opentelemetry # noqa: F401
except ImportError as exc:
raise ImportError(
"Cannot instrument. Please install 'taskiq[opentelemetry]'.",
) from exc


from opentelemetry.instrumentation.instrumentor import ( # type: ignore[attr-defined]
BaseInstrumentor,
)
from opentelemetry.instrumentation.utils import unwrap
from opentelemetry.metrics import MeterProvider
from opentelemetry.trace import TracerProvider
from wrapt import wrap_function_wrapper, wrap_object_attribute

from taskiq import AsyncBroker
from taskiq.cli.worker.process_manager import ProcessManager
from taskiq.middlewares.opentelemetry_middleware import OpenTelemetryMiddleware

logger = logging.getLogger("taskiq.opentelemetry")


def _worker_function_with_sitecustomize(
worker_function: Callable[[WorkerArgs], None],
*args: Any,
**kwargs: Any,
) -> None:
import opentelemetry.instrumentation.auto_instrumentation.sitecustomize # noqa
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import opentelemetry.instrumentation.auto_instrumentation.sitecustomize # noqa
from opentelemetry.instrumentation.auto_instrumentation import initialize
initialize()


return worker_function(*args, **kwargs)


def _worker_function_factory(
worker_function: Callable[[WorkerArgs], None],
) -> Callable[[WorkerArgs], None]:
return partial(_worker_function_with_sitecustomize, worker_function)


class TaskiqInstrumentor(BaseInstrumentor):
"""OpenTelemetry instrumentor for Taskiq."""

_instrumented_brokers: _WeakSet[AsyncBroker] = _WeakSet()

def __init__(self) -> None:
super().__init__()
self._middleware = None

def instrument_broker(
self,
broker: AsyncBroker,
tracer_provider: Optional[TracerProvider] = None,
meter_provider: Optional[MeterProvider] = None,
) -> None:
"""Instrument broker."""
if not hasattr(broker, "_is_instrumented_by_opentelemetry"):
broker._is_instrumented_by_opentelemetry = False # type: ignore[attr-defined] # noqa: SLF001

if not getattr(broker, "is_instrumented_by_opentelemetry", False):
broker.middlewares.insert(
0,
OpenTelemetryMiddleware(
tracer_provider=tracer_provider,
meter_provider=meter_provider,
),
)
broker._is_instrumented_by_opentelemetry = True # type: ignore[attr-defined] # noqa: SLF001
if broker not in self._instrumented_brokers:
self._instrumented_brokers.add(broker)
else:
logger.warning(
"Attempting to instrument taskiq broker while already instrumented",
)

def uninstrument_broker(self, broker: AsyncBroker) -> None:
"""Uninstrument broker."""
broker.middlewares = [
middleware
for middleware in broker.middlewares
if not isinstance(middleware, OpenTelemetryMiddleware)
]
broker._is_instrumented_by_opentelemetry = False # type: ignore[attr-defined] # noqa: SLF001
self._instrumented_brokers.discard(broker)

def instrumentation_dependencies(self) -> Collection[str]:
"""This function tells which library this instrumentor instruments."""
return ("taskiq >= 0.0.1",)

def _instrument(self, **kwargs: Any) -> None:
def broker_init(
init: Callable[[Any], Any],
broker: AsyncBroker,
args: Any,
kwargs: Any,
) -> None:
result = init(*args, **kwargs)
self.instrument_broker(broker)
return result

wrap_function_wrapper("taskiq.abc.broker", "AsyncBroker.__init__", broker_init)
wrap_object_attribute(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I guess it might be simpler to wrap taskiq.cli.worker.run:start_listen. At least it won't require any partial.

"taskiq.cli.worker.process_manager",
"ProcessManager.worker_function",
_worker_function_factory,
)

def _uninstrument(self, **kwargs: Any) -> None:
instances_to_uninstrument = list(self._instrumented_brokers)
for broker in instances_to_uninstrument:
self.uninstrument_broker(broker)
self._instrumented_brokers.clear()
unwrap(AsyncBroker, "__init__")
delattr(ProcessManager, "worker_function")
2 changes: 1 addition & 1 deletion taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ async def kiq(
except Exception as exc:
raise SendTaskError from exc

for middleware in self.broker.middlewares:
for middleware in reversed(self.broker.middlewares):
if middleware.__class__.post_send != TaskiqMiddleware.post_send:
await maybe_awaitable(middleware.post_send(message))

Expand Down
Loading