Skip to content

Commit 7cffe78

Browse files
authored
Merge branch 'master' into go-to-uv
2 parents c5a2f88 + 5c434f6 commit 7cffe78

File tree

5 files changed

+61
-28
lines changed

5 files changed

+61
-28
lines changed

taskiq/instrumentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def uninstrument_broker(self, broker: AsyncBroker) -> None:
115115

116116
def instrumentation_dependencies(self) -> Collection[str]:
117117
"""This function tells which library this instrumentor instruments."""
118-
return ("taskiq >= 0.0.1",)
118+
return ("taskiq >= 0.0.0",)
119119

120120
@classmethod
121121
def _start_listen_with_initialize(cls, args: WorkerArgs) -> None:

taskiq/middlewares/opentelemetry_middleware.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import logging
22
from contextlib import AbstractContextManager
3-
from typing import Any, TypeVar
3+
from importlib.metadata import version
4+
from typing import Any, Dict, Optional, Tuple, TypeVar
5+
6+
from packaging.version import Version, parse
47

58
try:
69
import opentelemetry # noqa: F401
@@ -10,6 +13,7 @@
1013
"Please install 'taskiq[opentelemetry]'.",
1114
) from exc
1215

16+
1317
from opentelemetry import context as context_api
1418
from opentelemetry import trace
1519
from opentelemetry.metrics import Meter, MeterProvider, get_meter
@@ -27,6 +31,16 @@
2731
# Taskiq Context key
2832
CTX_KEY = "__otel_task_span"
2933

34+
# unlike pydantic v2, v1 includes CTX_KEY by default
35+
# excluding it here
36+
PYDANTIC_VER = parse(version("pydantic"))
37+
IS_PYDANTIC1 = Version("2.0") > PYDANTIC_VER
38+
if IS_PYDANTIC1:
39+
if TaskiqMessage.__exclude_fields__: # type: ignore[attr-defined]
40+
TaskiqMessage.__exclude_fields__.update(CTX_KEY) # type: ignore
41+
else:
42+
TaskiqMessage.__exclude_fields__ = {CTX_KEY} # type: ignore
43+
3044
# Taskiq Context attributes
3145
TASKIQ_CONTEXT_ATTRIBUTES = [
3246
"_retries",
@@ -95,7 +109,9 @@ def attach_context(
95109

96110
if ctx_dict is None:
97111
ctx_dict = {}
98-
setattr(message, CTX_KEY, ctx_dict)
112+
# use object.__setattr__ directly
113+
# to skip pydantic v1 setattr
114+
object.__setattr__(message, CTX_KEY, ctx_dict)
99115

100116
ctx_dict[(message.task_id, is_publish)] = (span, activation, token)
101117

tests/api/test_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ async def test_successful() -> None:
1616
@broker.task(schedule=[{"time": datetime.now(timezone.utc) - timedelta(seconds=1)}])
1717
def _() -> None: ...
1818

19-
msg = await asyncio.wait_for(broker.queue.get(), 1)
19+
msg = await asyncio.wait_for(broker.queue.get(), 2)
2020
assert msg
2121

2222
scheduler_task.cancel()

tests/opentelemetry/test_auto_instrumentation.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
13
from opentelemetry.test.test_base import TestBase
24
from opentelemetry.trace import SpanKind, StatusCode
35

@@ -6,17 +8,20 @@
68

79

810
class TestTaskiqAutoInstrumentation(TestBase):
9-
async def test_auto_instrument(self) -> None:
11+
def test_auto_instrument(self) -> None:
1012
TaskiqInstrumentor().instrument()
1113

12-
broker = InMemoryBroker(await_inplace=True)
14+
async def test() -> None:
15+
broker = InMemoryBroker(await_inplace=True)
16+
17+
@broker.task
18+
async def task_add(a: float, b: float) -> float:
19+
return a + b
1320

14-
@broker.task
15-
async def task_add(a: float, b: float) -> float:
16-
return a + b
21+
await task_add.kiq(1, 2)
22+
await broker.wait_all()
1723

18-
await task_add.kiq(1, 2)
19-
await broker.wait_all()
24+
asyncio.run(test())
2025

2126
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
2227
self.assertEqual(len(spans), 2)
@@ -25,15 +30,15 @@ async def task_add(a: float, b: float) -> float:
2530

2631
self.assertEqual(
2732
consumer.name,
28-
"execute/tests.test_auto_instrumentation:task_add",
33+
"execute/tests.opentelemetry.test_auto_instrumentation:task_add",
2934
f"{consumer._end_time}:{producer._end_time}",
3035
)
3136
self.assertEqual(consumer.kind, SpanKind.CONSUMER)
3237
self.assertSpanHasAttributes(
3338
consumer,
3439
{
3540
"taskiq.action": "execute",
36-
"taskiq.task_name": "tests.test_auto_instrumentation:task_add",
41+
"taskiq.task_name": "tests.opentelemetry.test_auto_instrumentation:task_add", # noqa: E501
3742
},
3843
)
3944

@@ -43,14 +48,14 @@ async def task_add(a: float, b: float) -> float:
4348

4449
self.assertEqual(
4550
producer.name,
46-
"send/tests.test_auto_instrumentation:task_add",
51+
"send/tests.opentelemetry.test_auto_instrumentation:task_add",
4752
)
4853
self.assertEqual(producer.kind, SpanKind.PRODUCER)
4954
self.assertSpanHasAttributes(
5055
producer,
5156
{
5257
"taskiq.action": "send",
53-
"taskiq.task_name": "tests.test_auto_instrumentation:task_add",
58+
"taskiq.task_name": "tests.opentelemetry.test_auto_instrumentation:task_add", # noqa: E501
5459
},
5560
)
5661

tests/opentelemetry/test_tasks.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from opentelemetry.trace import Span, SpanKind, StatusCode
1111
from wrapt import wrap_function_wrapper
1212

13+
from taskiq import TaskiqResult
1314
from taskiq.instrumentation import TaskiqInstrumentor
1415
from taskiq.middlewares import opentelemetry_middleware
1516

@@ -26,11 +27,14 @@ def tearDown(self) -> None:
2627
super().tearDown()
2728
TaskiqInstrumentor().uninstrument_broker(broker)
2829

29-
async def test_task(self) -> None:
30+
def test_task(self) -> None:
3031
TaskiqInstrumentor().instrument_broker(broker)
3132

32-
await task_add.kiq(1, 2)
33-
await broker.wait_all()
33+
async def test() -> None:
34+
await task_add.kiq(1, 2)
35+
await broker.wait_all()
36+
37+
asyncio.run(test())
3438

3539
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
3640
self.assertEqual(len(spans), 2)
@@ -72,11 +76,14 @@ async def test_task(self) -> None:
7276
self.assertEqual(consumer.parent.span_id, producer.context.span_id)
7377
self.assertEqual(consumer.context.trace_id, producer.context.trace_id)
7478

75-
async def test_task_raises(self) -> None:
79+
def test_task_raises(self) -> None:
7680
TaskiqInstrumentor().instrument_broker(broker)
7781

78-
await task_raises.kiq()
79-
await broker.wait_all()
82+
async def test() -> None:
83+
await task_raises.kiq()
84+
await broker.wait_all()
85+
86+
asyncio.run(test())
8087

8188
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
8289
self.assertEqual(len(spans), 2)
@@ -130,7 +137,7 @@ async def test_task_raises(self) -> None:
130137
self.assertEqual(consumer.parent.span_id, producer.context.span_id)
131138
self.assertEqual(consumer.context.trace_id, producer.context.trace_id)
132139

133-
async def test_uninstrument(self) -> None:
140+
def test_uninstrument(self) -> None:
134141
TaskiqInstrumentor().instrument_broker(broker)
135142
TaskiqInstrumentor().uninstrument_broker(broker)
136143

@@ -143,18 +150,21 @@ async def test() -> None:
143150
spans = self.memory_exporter.get_finished_spans()
144151
self.assertEqual(len(spans), 0)
145152

146-
async def test_baggage(self) -> None:
153+
def test_baggage(self) -> None:
147154
TaskiqInstrumentor().instrument_broker(broker)
148155

156+
async def test() -> TaskiqResult[Any]:
157+
task = await task_returns_baggage.kiq()
158+
return await task.wait_result(timeout=2)
159+
149160
ctx = baggage.set_baggage("key", "value")
150161
context.attach(ctx)
151162

152-
task = await task_returns_baggage.kiq()
153-
result = await task.wait_result(timeout=2)
163+
result = asyncio.run(test())
154164

155165
self.assertEqual(result.return_value, {"key": "value"})
156166

157-
async def test_task_not_instrumented_does_not_raise(self) -> None:
167+
def test_task_not_instrumented_does_not_raise(self) -> None:
158168
def _retrieve_context_wrapper_none_token(
159169
wrapped: Callable[
160170
[Any],
@@ -178,9 +188,11 @@ def _retrieve_context_wrapper_none_token(
178188

179189
TaskiqInstrumentor().instrument_broker(broker)
180190

181-
task = await task_add.kiq(1, 2)
182-
result = await task.wait_result(timeout=2)
191+
async def test() -> TaskiqResult[float]:
192+
task = await task_add.kiq(1, 2)
193+
return await task.wait_result(timeout=2)
183194

195+
result = asyncio.run(test())
184196
spans = self.sorted_spans(self.memory_exporter.get_finished_spans())
185197
self.assertEqual(len(spans), 2)
186198

0 commit comments

Comments
 (0)