Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion taskiq/kicker.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ def with_labels(
self.labels.update(labels)
return self

def with_task_id(self, task_id: str) -> "AsyncKicker[_FuncParams, _ReturnType]":
def with_task_id(
self,
task_id: Optional[str],
) -> "AsyncKicker[_FuncParams, _ReturnType]":
"""
Set task_id for current execution.

Expand Down Expand Up @@ -208,6 +211,7 @@ async def schedule_by_cron(
labels=message.labels,
args=message.args,
kwargs=message.kwargs,
task_id=self.custom_task_id,
cron=cron_str,
cron_offset=cron_offset,
)
Expand Down Expand Up @@ -239,6 +243,7 @@ async def schedule_by_time(
labels=message.labels,
args=message.args,
kwargs=message.kwargs,
task_id=self.custom_task_id,
time=time,
)
await source.add_schedule(scheduled)
Expand Down
1 change: 1 addition & 0 deletions taskiq/scheduler/scheduled_task/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class ScheduledTask(BaseModel):
labels: Dict[str, Any]
args: List[Any]
kwargs: Dict[str, Any]
task_id: Optional[str] = None
schedule_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
cron: Optional[str] = None
cron_offset: Optional[Union[str, timedelta]] = None
Expand Down
1 change: 1 addition & 0 deletions taskiq/scheduler/scheduled_task/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class ScheduledTask(BaseModel):
labels: Dict[str, Any]
args: List[Any]
kwargs: Dict[str, Any]
task_id: Optional[str] = None
schedule_id: str = Field(default_factory=lambda: uuid.uuid4().hex)
cron: Optional[str] = None
cron_offset: Optional[Union[str, timedelta]] = None
Expand Down
1 change: 1 addition & 0 deletions taskiq/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None:
.with_labels(
schedule_id=task.schedule_id,
)
.with_task_id(task_id=task.task_id)
.kiq(
*task.args,
**task.kwargs,
Expand Down
43 changes: 43 additions & 0 deletions tests/test_retry_task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest

from taskiq import (
Context,
InMemoryBroker,
SmartRetryMiddleware,
TaskiqDepends,
TaskiqScheduler,
)
from taskiq.schedule_sources import LabelScheduleSource


@pytest.mark.parametrize(
"retry_count",
range(5),
)
@pytest.mark.anyio
async def test_save_task_id_for_retry(retry_count: int) -> None:
broker = InMemoryBroker().with_middlewares(
SmartRetryMiddleware(
default_retry_count=retry_count + 1,
default_delay=0.1,
),
)
scheduler = TaskiqScheduler(broker, [LabelScheduleSource(broker)])

check_interval = 0.5

@broker.task("exc_task", retry_on_error=True)
async def exc_task(count: int = 0, context: "Context" = TaskiqDepends()) -> int:
retry = int(context.message.labels.get("_retries", 0))
if retry < count:
raise Exception("test")
return retry

await broker.startup()
await scheduler.startup()

task_with_retry = await exc_task.kiq(retry_count)
task_with_retry_result = await task_with_retry.wait_result(
check_interval=check_interval,
)
assert task_with_retry_result.return_value == retry_count
Loading