Skip to content

Commit 02c2817

Browse files
authored
Removed source from the TaskiqSchedule. (#218)
1 parent 43af7a9 commit 02c2817

File tree

9 files changed

+53
-48
lines changed

9 files changed

+53
-48
lines changed

docs/examples/extending/schedule_source.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,6 @@ async def get_schedules(self) -> List["ScheduledTask"]:
1919
args=[],
2020
kwargs={},
2121
cron="* * * * *",
22-
#
23-
# We need point on self source for calling pre_send / post_send when
24-
# task is ready to be enqueued.
25-
source=self,
2622
),
2723
]
2824

taskiq/abc/schedule_source.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ async def shutdown(self) -> None: # noqa: B027
1818
async def get_schedules(self) -> List["ScheduledTask"]:
1919
"""Get list of taskiq schedules."""
2020

21-
async def add_schedule(self, schedule: "ScheduledTask") -> None: # noqa: B027
21+
async def add_schedule(self, schedule: "ScheduledTask") -> None:
2222
"""
2323
Add a new schedule.
2424
@@ -33,6 +33,9 @@ async def add_schedule(self, schedule: "ScheduledTask") -> None: # noqa: B027
3333
3434
:param schedule: schedule to add.
3535
"""
36+
raise NotImplementedError(
37+
f"The source {self.__class__.__name__} does not support adding schedules.",
38+
)
3639

3740
def pre_send( # noqa: B027
3841
self,

taskiq/cli/scheduler/run.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
import sys
33
from datetime import datetime, timedelta
44
from logging import basicConfig, getLevelName, getLogger
5-
from typing import List, Optional
5+
from typing import Dict, List, Optional
66

77
import pytz
88
from pycron import is_now
99

10+
from taskiq.abc.schedule_source import ScheduleSource
1011
from taskiq.cli.scheduler.args import SchedulerArgs
1112
from taskiq.cli.utils import import_object, import_tasks
1213
from taskiq.scheduler.scheduler import ScheduledTask, TaskiqScheduler
@@ -32,7 +33,7 @@ def to_tz_aware(time: datetime) -> datetime:
3233

3334
async def schedules_updater(
3435
scheduler: TaskiqScheduler,
35-
current_schedules: List[ScheduledTask],
36+
current_schedules: Dict[ScheduleSource, List[ScheduledTask]],
3637
event: asyncio.Event,
3738
) -> None:
3839
"""
@@ -48,7 +49,7 @@ async def schedules_updater(
4849
"""
4950
while True:
5051
logger.debug("Started schedule update.")
51-
new_schedules: "List[ScheduledTask]" = []
52+
new_schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
5253
for source in scheduler.sources:
5354
try:
5455
schedules = await source.get_schedules()
@@ -60,10 +61,13 @@ async def schedules_updater(
6061
logger.debug(exc, exc_info=True)
6162
continue
6263

63-
new_schedules = scheduler.merge_func(new_schedules, schedules)
64+
new_schedules[source] = scheduler.merge_func(
65+
new_schedules.get(source) or [],
66+
schedules,
67+
)
6468

6569
current_schedules.clear()
66-
current_schedules.extend(new_schedules)
70+
current_schedules.update(new_schedules)
6771
event.set()
6872
await asyncio.sleep(scheduler.refresh_delay)
6973

@@ -100,6 +104,7 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]:
100104

101105
async def delayed_send(
102106
scheduler: TaskiqScheduler,
107+
source: ScheduleSource,
103108
task: ScheduledTask,
104109
delay: int,
105110
) -> None:
@@ -115,13 +120,14 @@ async def delayed_send(
115120
the delay and send the task after some delay.
116121
117122
:param scheduler: current scheduler.
123+
:param source: source of the task.
118124
:param task: task to send.
119125
:param delay: how long to wait.
120126
"""
121127
if delay > 0:
122128
await asyncio.sleep(delay)
123129
logger.info("Sending task %s.", task.task_name)
124-
await scheduler.on_ready(task)
130+
await scheduler.on_ready(source, task)
125131

126132

127133
async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
@@ -134,33 +140,34 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
134140
:param scheduler: current scheduler.
135141
"""
136142
loop = asyncio.get_event_loop()
137-
tasks: "List[ScheduledTask]" = []
143+
schedules: "Dict[ScheduleSource, List[ScheduledTask]]" = {}
138144

139145
current_task = asyncio.current_task()
140146
first_update_event = asyncio.Event()
141147
updater_task = loop.create_task(
142148
schedules_updater(
143149
scheduler,
144-
tasks,
150+
schedules,
145151
first_update_event,
146152
),
147153
)
148154
if current_task is not None:
149155
current_task.add_done_callback(lambda _: updater_task.cancel())
150156
await first_update_event.wait()
151157
while True:
152-
for task in tasks:
153-
try:
154-
task_delay = get_task_delay(task)
155-
except ValueError:
156-
logger.warning(
157-
"Cannot parse cron: %s for task: %s",
158-
task.cron,
159-
task.task_name,
160-
)
161-
continue
162-
if task_delay is not None:
163-
loop.create_task(delayed_send(scheduler, task, task_delay))
158+
for source, task_list in schedules.items():
159+
for task in task_list:
160+
try:
161+
task_delay = get_task_delay(task)
162+
except ValueError:
163+
logger.warning(
164+
"Cannot parse cron: %s for task: %s",
165+
task.cron,
166+
task.task_name,
167+
)
168+
continue
169+
if task_delay is not None:
170+
loop.create_task(delayed_send(scheduler, source, task, task_delay))
164171

165172
delay = (
166173
datetime.now().replace(second=1, microsecond=0)

taskiq/schedule_sources/label_based.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ async def get_schedules(self) -> List["ScheduledTask"]:
4444
cron=schedule.get("cron"),
4545
time=schedule.get("time"),
4646
cron_offset=schedule.get("cron_offset"),
47-
source=self,
4847
),
4948
)
5049
return schedules

taskiq/scheduler/merge_functions.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import copy
12
from typing import TYPE_CHECKING, List
23

34
if TYPE_CHECKING: # pragma: no cover
@@ -34,8 +35,22 @@ def only_unique(
3435
:param new_tasks: newly discovered tasks.
3536
:return: list of unique schedules.
3637
"""
37-
result = old_tasks
38+
result = copy.copy(old_tasks)
3839
for task in new_tasks:
3940
if task not in result:
4041
result.append(task)
4142
return result
43+
44+
45+
def only_new(
46+
_old_tasks: List["ScheduledTask"],
47+
new_tasks: List["ScheduledTask"],
48+
) -> List["ScheduledTask"]:
49+
"""
50+
This function preserves only new schedules.
51+
52+
:param old_tasks: previously discovered tasks.
53+
:param new_tasks: newly discovered schedules.
54+
:return: list of new schedules.
55+
"""
56+
return new_tasks

taskiq/scheduler/scheduler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from taskiq.abc.broker import AsyncBroker
66
from taskiq.kicker import AsyncKicker
7-
from taskiq.scheduler.merge_functions import preserve_all
7+
from taskiq.scheduler.merge_functions import only_new
88
from taskiq.utils import maybe_awaitable
99

1010
if TYPE_CHECKING: # pragma: no cover
@@ -19,7 +19,6 @@ class ScheduledTask:
1919
labels: Dict[str, Any]
2020
args: List[Any]
2121
kwargs: Dict[str, Any]
22-
source: "ScheduleSource" # Backward point to source which this task belongs to
2322
cron: Optional[str] = field(default=None)
2423
cron_offset: Optional[Union[str, timedelta]] = field(default=None)
2524
time: Optional[datetime] = field(default=None)
@@ -44,7 +43,7 @@ def __init__(
4443
merge_func: Callable[
4544
[List["ScheduledTask"], List["ScheduledTask"]],
4645
List["ScheduledTask"],
47-
] = preserve_all,
46+
] = only_new,
4847
refresh_delay: float = 30.0,
4948
) -> None: # pragma: no cover
5049
self.broker = broker
@@ -61,19 +60,19 @@ async def startup(self) -> None: # pragma: no cover
6160
"""
6261
await self.broker.startup()
6362

64-
async def on_ready(self, task: ScheduledTask) -> None:
63+
async def on_ready(self, source: "ScheduleSource", task: ScheduledTask) -> None:
6564
"""
6665
This method is called when task is ready to be enqueued.
6766
6867
It's triggered on proper time depending on `task.cron` or `task.time` attribute.
6968
:param task: task to send
7069
"""
71-
await maybe_awaitable(task.source.pre_send(task))
70+
await maybe_awaitable(source.pre_send(task))
7271
await AsyncKicker(task.task_name, self.broker, task.labels).kiq(
7372
*task.args,
7473
**task.kwargs,
7574
)
76-
await maybe_awaitable(task.source.post_send(task))
75+
await maybe_awaitable(source.post_send(task))
7776

7877
async def shutdown(self) -> None:
7978
"""Shutdown the scheduler process."""

tests/cli/scheduler/test_task_delays.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@
55
from tzlocal import get_localzone
66

77
from taskiq.cli.scheduler.run import get_task_delay
8-
from taskiq.schedule_sources.label_based import LabelScheduleSource
98
from taskiq.scheduler.scheduler import ScheduledTask
109

11-
DUMMY_SOURCE = LabelScheduleSource(broker=None) # type: ignore
12-
1310

1411
def test_should_run_success() -> None:
1512
hour = datetime.datetime.utcnow().hour
@@ -19,7 +16,6 @@ def test_should_run_success() -> None:
1916
labels={},
2017
args=[],
2118
kwargs={},
22-
source=DUMMY_SOURCE,
2319
cron=f"* {hour} * * *",
2420
),
2521
)
@@ -35,7 +31,6 @@ def test_should_run_cron_str_offset() -> None:
3531
labels={},
3632
args=[],
3733
kwargs={},
38-
source=DUMMY_SOURCE,
3934
cron=f"* {hour} * * *",
4035
cron_offset=str(zone),
4136
),
@@ -52,7 +47,6 @@ def test_should_run_cron_td_offset() -> None:
5247
labels={},
5348
args=[],
5449
kwargs={},
55-
source=DUMMY_SOURCE,
5650
cron=f"* {hour} * * *",
5751
cron_offset=datetime.timedelta(hours=offset),
5852
),
@@ -68,7 +62,6 @@ def test_time_utc_without_zone() -> None:
6862
labels={},
6963
args=[],
7064
kwargs={},
71-
source=DUMMY_SOURCE,
7265
time=time - datetime.timedelta(seconds=1),
7366
),
7467
)
@@ -83,7 +76,6 @@ def test_time_utc_with_zone() -> None:
8376
labels={},
8477
args=[],
8578
kwargs={},
86-
source=DUMMY_SOURCE,
8779
time=time - datetime.timedelta(seconds=1),
8880
),
8981
)
@@ -99,7 +91,6 @@ def test_time_utc_with_local_zone() -> None:
9991
labels={},
10092
args=[],
10193
kwargs={},
102-
source=DUMMY_SOURCE,
10394
time=time - datetime.timedelta(seconds=1),
10495
),
10596
)
@@ -114,7 +105,6 @@ def test_time_localtime_without_zone() -> None:
114105
labels={},
115106
args=[],
116107
kwargs={},
117-
source=DUMMY_SOURCE,
118108
time=time - datetime.timedelta(seconds=1),
119109
),
120110
)
@@ -130,7 +120,6 @@ def test_time_delay() -> None:
130120
labels={},
131121
args=[],
132122
kwargs={},
133-
source=DUMMY_SOURCE,
134123
time=time,
135124
),
136125
)

tests/schedule_sources/test_label_based.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ def task() -> None:
3636
labels={"schedule": schedule_label},
3737
args=[],
3838
kwargs={},
39-
source=source,
4039
),
4140
]
4241

tests/scheduler/test_label_based_sched.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
3030
def task() -> None:
3131
pass
3232

33-
source = LabelScheduleSource(broker)
34-
schedules = await source.get_schedules()
33+
schedules = await LabelScheduleSource(broker).get_schedules()
3534
assert schedules == [
3635
ScheduledTask(
3736
cron=schedule_label[0].get("cron"),
@@ -40,7 +39,6 @@ def task() -> None:
4039
labels={"schedule": schedule_label},
4140
args=[],
4241
kwargs={},
43-
source=source,
4442
),
4543
]
4644

0 commit comments

Comments
 (0)