Skip to content

Commit 6383e01

Browse files
committed
Switch to broker.serializer instead of JSON-only solution
1 parent 4c3f0aa commit 6383e01

File tree

7 files changed

+34
-89
lines changed

7 files changed

+34
-89
lines changed

taskiq_pipelines/abc.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,6 @@ def __init_subclass__(cls, step_name: str, **kwargs: Any) -> None:
2020
# known steps.
2121
cls._known_steps[step_name] = cls
2222

23-
@abstractmethod
24-
def dumps(self) -> str:
25-
"""
26-
Generate parsable string.
27-
28-
:return: dumped object.
29-
"""
30-
31-
@classmethod
32-
@abstractmethod
33-
def loads(cls: Type[_T], data: str) -> _T:
34-
"""
35-
Method to load previously dumped data.
36-
37-
:param data: dumped data.
38-
:return: instance of a class.
39-
"""
40-
4123
@abstractmethod
4224
async def act(
4325
self,

taskiq_pipelines/middleware.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class PipelineMiddleware(TaskiqMiddleware):
1616
"""Pipeline middleware."""
1717

18-
async def post_save( # noqa: C901, WPS212
18+
async def post_save( # noqa: C901, WPS210, WPS212
1919
self,
2020
message: "TaskiqMessage",
2121
result: "TaskiqResult[Any]",
@@ -41,9 +41,21 @@ async def post_save( # noqa: C901, WPS212
4141
logger.warn("Pipeline data not found. Execution flow is broken.")
4242
return
4343
pipeline_data = message.labels[PIPELINE_DATA]
44+
# workaround for obligatory casting label values to `str`
45+
# in `AsyncKicker._prepare_message`.
46+
# The trick can be removed later after adding explicit `bytes` support.
47+
if ( # noqa: WPS337
48+
isinstance(pipeline_data, str)
49+
and pipeline_data.startswith("b'")
50+
and pipeline_data.endswith("'")
51+
):
52+
pipeline_data2 = pipeline_data[2:-1].encode()
53+
else:
54+
pipeline_data2 = pipeline_data
55+
parsed_data = self.broker.serializer.loadb(pipeline_data2)
4456
try:
45-
steps_data = pydantic.TypeAdapter(List[DumpedStep]).validate_json(
46-
pipeline_data,
57+
steps_data = pydantic.TypeAdapter(List[DumpedStep]).validate_python(
58+
parsed_data,
4759
)
4860
except ValueError:
4961
return

taskiq_pipelines/pipeliner.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
import json
21
from typing import (
32
Any,
43
Coroutine,
4+
Dict,
55
Generic,
66
List,
77
Literal,
@@ -29,10 +29,13 @@ class DumpedStep(pydantic.BaseModel):
2929
"""Dumped state model."""
3030

3131
step_type: str
32-
step_data: str
32+
step_data: Dict[str, Any]
3333
task_id: str
3434

3535

36+
DumpedSteps = pydantic.RootModel[List[DumpedStep]]
37+
38+
3639
class Pipeline(Generic[_FuncParams, _ReturnType]):
3740
"""
3841
Pipeline constructor.
@@ -116,7 +119,7 @@ def call_next(
116119
task=task,
117120
param_name=param_name,
118121
**additional_kwargs,
119-
).dumps(),
122+
).model_dump(),
120123
task_id="",
121124
),
122125
)
@@ -172,7 +175,7 @@ def call_after(
172175
task=task,
173176
param_name=EMPTY_PARAM_NAME,
174177
**additional_kwargs,
175-
).dumps(),
178+
).model_dump(),
176179
task_id="",
177180
),
178181
)
@@ -243,7 +246,7 @@ def map(
243246
skip_errors=skip_errors,
244247
check_interval=check_interval,
245248
**additional_kwargs,
246-
).dumps(),
249+
).model_dump(),
247250
task_id="",
248251
),
249252
)
@@ -315,24 +318,24 @@ def filter(
315318
skip_errors=skip_errors,
316319
check_interval=check_interval,
317320
**additional_kwargs,
318-
).dumps(),
321+
).model_dump(),
319322
task_id="",
320323
),
321324
)
322325
return self
323326

324-
def dumps(self) -> str:
327+
def dumpb(self) -> bytes:
325328
"""
326329
Dumps current pipeline as bytes.
327330
328331
:returns: serialized pipeline.
329332
"""
330-
return json.dumps(
331-
[step.model_dump() for step in self.steps],
333+
return self.broker.serializer.dumpb(
334+
DumpedSteps.model_validate(self.steps).model_dump(),
332335
)
333336

334337
@classmethod
335-
def loads(cls, broker: AsyncBroker, pipe_data: str) -> "Pipeline[Any, Any]":
338+
def loadb(cls, broker: AsyncBroker, pipe_data: bytes) -> "Pipeline[Any, Any]":
336339
"""
337340
Parses serialized pipeline.
338341
@@ -344,7 +347,8 @@ def loads(cls, broker: AsyncBroker, pipe_data: str) -> "Pipeline[Any, Any]":
344347
:return: new pipeline instance.
345348
"""
346349
pipe: "Pipeline[Any, Any]" = Pipeline(broker)
347-
pipe.steps = pydantic.TypeAdapter(List[DumpedStep]).validate_json(pipe_data)
350+
data = broker.serializer.loadb(pipe_data)
351+
pipe.steps = DumpedSteps.model_validate(data) # type: ignore[assignment]
348352
return pipe
349353

350354
async def kiq(
@@ -383,7 +387,7 @@ async def kiq(
383387
)
384388
.with_task_id(step.task_id)
385389
.with_labels(
386-
**{CURRENT_STEP: 0, PIPELINE_DATA: self.dumps()}, # type: ignore
390+
**{CURRENT_STEP: 0, PIPELINE_DATA: self.dumpb()}, # type: ignore
387391
)
388392
)
389393
taskiq_task = await kicker.kiq(*args, **kwargs)

taskiq_pipelines/steps/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Package with default pipeline steps."""
22
from logging import getLogger
3+
from typing import Any, Dict
34

45
from taskiq_pipelines.abc import AbstractStep
56
from taskiq_pipelines.steps.filter import FilterStep
@@ -9,12 +10,12 @@
910
logger = getLogger(__name__)
1011

1112

12-
def parse_step(step_type: str, step_data: str) -> AbstractStep:
13+
def parse_step(step_type: str, step_data: Dict[str, Any]) -> AbstractStep:
1314
step_cls = AbstractStep._known_steps.get(step_type) # noqa: WPS437
1415
if step_cls is None:
1516
logger.warning(f"Unknown step type: {step_type}")
1617
raise ValueError("Unknown step type.")
17-
return step_cls.loads(step_data)
18+
return step_cls(**step_data)
1819

1920

2021
__all__ = [

taskiq_pipelines/steps/filter.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -78,24 +78,6 @@ class FilterStep(pydantic.BaseModel, AbstractStep, step_name="filter"):
7878
skip_errors: bool
7979
check_interval: float
8080

81-
def dumps(self) -> str:
82-
"""
83-
Dumps step as string.
84-
85-
:return: returns json.
86-
"""
87-
return self.model_dump_json()
88-
89-
@classmethod
90-
def loads(cls, data: str) -> "FilterStep":
91-
"""
92-
Parses mapper step from string.
93-
94-
:param data: dumped data.
95-
:return: parsed step.
96-
"""
97-
return pydantic.TypeAdapter(FilterStep).validate_json(data)
98-
9981
async def act(
10082
self,
10183
broker: AsyncBroker,

taskiq_pipelines/steps/mapper.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -75,24 +75,6 @@ class MapperStep(pydantic.BaseModel, AbstractStep, step_name="mapper"):
7575
skip_errors: bool
7676
check_interval: float
7777

78-
def dumps(self) -> str:
79-
"""
80-
Dumps step as string.
81-
82-
:return: returns json.
83-
"""
84-
return self.model_dump_json()
85-
86-
@classmethod
87-
def loads(cls, data: str) -> "MapperStep":
88-
"""
89-
Parses mapper step from string.
90-
91-
:param data: dumped data.
92-
:return: parsed step.
93-
"""
94-
return pydantic.TypeAdapter(MapperStep).validate_json(data)
95-
9678
async def act(
9779
self,
9880
broker: AsyncBroker,

taskiq_pipelines/steps/sequential.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,6 @@ class SequentialStep(pydantic.BaseModel, AbstractStep, step_name="sequential"):
2424
param_name: Union[Optional[int], str]
2525
additional_kwargs: Dict[str, Any]
2626

27-
def dumps(self) -> str:
28-
"""
29-
Dumps step as string.
30-
31-
:return: returns json.
32-
"""
33-
return self.model_dump_json()
34-
35-
@classmethod
36-
def loads(cls, data: str) -> "SequentialStep":
37-
"""
38-
Parses sequential step from string.
39-
40-
:param data: dumped data.
41-
:return: parsed step.
42-
"""
43-
return pydantic.TypeAdapter(SequentialStep).validate_json(data)
44-
4527
async def act(
4628
self,
4729
broker: AsyncBroker,

0 commit comments

Comments
 (0)