Skip to content

Commit 47d1fbd

Browse files
committed
Added second-based scheduler intervals.
1 parent 23422c7 commit 47d1fbd

File tree

6 files changed

+137
-67
lines changed

6 files changed

+137
-67
lines changed

taskiq/cli/scheduler/args.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class SchedulerArgs:
1717
fs_discover: bool = False
1818
tasks_pattern: Sequence[str] = ("**/tasks.py",)
1919
skip_first_run: bool = False
20+
update_interval: Optional[int] = None
2021

2122
@classmethod
2223
def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs":
@@ -80,6 +81,15 @@ def from_cli(cls, args: Optional[Sequence[str]] = None) -> "SchedulerArgs":
8081
"This option skips running tasks immediately after scheduler start."
8182
),
8283
)
84+
parser.add_argument(
85+
"--update-interval",
86+
type=int,
87+
default=None,
88+
help=(
89+
"Interval in seconds to check for new tasks. "
90+
"If not specified, scheduler will run once a minute."
91+
),
92+
)
8393

8494
namespace = parser.parse_args(args)
8595
# If there are any patterns specified, remove default.

taskiq/cli/scheduler/run.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import sys
44
from datetime import datetime, timedelta
55
from logging import basicConfig, getLevelName, getLogger
6-
from typing import Dict, List, Optional
6+
from typing import Any, Dict, List, Optional, Set
77

88
import pytz
99
from pycron import is_now
@@ -98,12 +98,10 @@ def get_task_delay(task: ScheduledTask) -> Optional[int]:
9898
task_time = to_tz_aware(task.time)
9999
if task_time <= now:
100100
return 0
101-
one_min_ahead = (now + timedelta(minutes=1)).replace(second=1, microsecond=0)
102-
if task_time <= one_min_ahead:
103-
delay = task_time - now
104-
if delay.microseconds:
105-
return int(delay.total_seconds()) + 1
106-
return int(delay.total_seconds())
101+
delay = task_time - now
102+
if delay.microseconds:
103+
return int(delay.total_seconds()) + 1
104+
return int(delay.total_seconds())
107105
return None
108106

109107

@@ -145,7 +143,10 @@ async def delayed_send(
145143
await scheduler.on_ready(source, task)
146144

147145

148-
async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
146+
async def run_scheduler_loop( # noqa: C901
147+
scheduler: TaskiqScheduler,
148+
interval: Optional[timedelta] = None,
149+
) -> None:
149150
"""
150151
Runs scheduler loop.
151152
@@ -155,9 +156,18 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
155156
:param scheduler: current scheduler.
156157
"""
157158
loop = asyncio.get_event_loop()
158-
running_schedules = set()
159+
running_schedules: Dict[str, asyncio.Task[Any]] = {}
160+
ran_cron_jobs: Set[str] = set()
161+
current_minute = datetime.now(tz=pytz.UTC).minute
159162
while True:
160-
# We use this method to correctly sleep for one minute.
163+
now = datetime.now(tz=pytz.UTC)
164+
if now.minute != current_minute:
165+
current_minute = now.minute
166+
ran_cron_jobs.clear()
167+
if interval is not None:
168+
next_run = now + interval
169+
else:
170+
next_run = (now + timedelta(minutes=1)).replace(second=1, microsecond=0)
161171
scheduled_tasks = await get_all_schedules(scheduler)
162172
for source, task_list in scheduled_tasks.items():
163173
logger.debug("Got %d schedules from source %s.", len(task_list), source)
@@ -172,16 +182,37 @@ async def run_scheduler_loop(scheduler: TaskiqScheduler) -> None:
172182
task.schedule_id,
173183
)
174184
continue
175-
if task_delay is not None:
176-
send_task = loop.create_task(
177-
delayed_send(scheduler, source, task, task_delay),
178-
)
179-
running_schedules.add(send_task)
180-
send_task.add_done_callback(running_schedules.discard)
181-
next_minute = datetime.now().replace(second=0, microsecond=0) + timedelta(
182-
minutes=1,
183-
)
184-
delay = next_minute - datetime.now()
185+
# If task delay is None, we don't need to run it.
186+
if task_delay is None:
187+
continue
188+
# If task is delayed for more than next_run,
189+
# we don't need to run it, because we will
190+
# run it in the next iteration.
191+
if now + timedelta(seconds=task_delay) >= next_run:
192+
continue
193+
# If task is already running, we don't need to run it again.
194+
if task.schedule_id in running_schedules and task_delay < 1:
195+
continue
196+
# If task is cron job, we need to check if
197+
# we already ran it this minute.
198+
if task.cron is not None:
199+
if task.schedule_id in ran_cron_jobs:
200+
continue
201+
ran_cron_jobs.add(task.schedule_id)
202+
send_task = loop.create_task(
203+
delayed_send(scheduler, source, task, task_delay),
204+
# We need to set the name of the task
205+
# to be able to discard its reference
206+
# after it is done.
207+
name=f"schedule_{task.schedule_id}",
208+
)
209+
running_schedules[task.schedule_id] = send_task
210+
send_task.add_done_callback(
211+
lambda task_future: running_schedules.pop(
212+
task_future.get_name().removeprefix("schedule_"),
213+
),
214+
)
215+
delay = next_run - datetime.now(tz=pytz.UTC)
185216
logger.debug(
186217
"Sleeping for %.2f seconds before getting schedules.",
187218
delay.total_seconds(),
@@ -226,6 +257,10 @@ async def run_scheduler(args: SchedulerArgs) -> None:
226257
for source in scheduler.sources:
227258
await source.startup()
228259

260+
interval = None
261+
if args.update_interval:
262+
interval = timedelta(seconds=args.update_interval)
263+
229264
logger.info("Starting scheduler.")
230265
await scheduler.startup()
231266
logger.info("Startup completed.")
@@ -239,7 +274,7 @@ async def run_scheduler(args: SchedulerArgs) -> None:
239274
await asyncio.sleep(delay.total_seconds())
240275
logger.info("First run skipped. The scheduler is now running.")
241276
try:
242-
await run_scheduler_loop(scheduler)
277+
await run_scheduler_loop(scheduler, interval)
243278
except asyncio.CancelledError:
244279
logger.warning("Shutting down scheduler.")
245280
await scheduler.shutdown()

taskiq/schedule_sources/label_based.py

Lines changed: 46 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import uuid
12
from logging import getLogger
2-
from typing import List
3+
from typing import Dict, List
34

45
from taskiq.abc.broker import AsyncBroker
56
from taskiq.abc.schedule_source import ScheduleSource
@@ -13,20 +14,26 @@ class LabelScheduleSource(ScheduleSource):
1314

1415
def __init__(self, broker: AsyncBroker) -> None:
1516
self.broker = broker
17+
self.schedules: Dict[str, ScheduledTask] = {}
1618

17-
async def get_schedules(self) -> List["ScheduledTask"]:
19+
async def startup(self) -> None:
1820
"""
19-
Collect schedules for all tasks.
20-
21-
this function checks labels for all
22-
tasks available to the broker.
21+
Startup the schedule source.
2322
23+
This function iterates over all tasks
24+
available to the broker and collects
25+
schedules from their labels.
2426
If task has a schedule label,
25-
it will be parsed and returned.
27+
it will be parsed and added to the
28+
scheduler list.
2629
27-
:return: list of schedules.
30+
Every time a schedule is added, a random
31+
schedule id is generated. Please be aware that
32+
these ids are different on every startup.
33+
34+
:return: None
2835
"""
29-
schedules = []
36+
self.schedules.clear()
3037
for task_name, task in self.broker.get_all_tasks().items():
3138
if task.broker != self.broker:
3239
# if task broker doesn't match self, something is probably wrong
@@ -40,20 +47,36 @@ async def get_schedules(self) -> List["ScheduledTask"]:
4047
continue
4148
labels = schedule.get("labels", {})
4249
labels.update(task.labels)
43-
schedules.append(
44-
ScheduledTask(
45-
task_name=task_name,
46-
labels=labels,
47-
args=schedule.get("args", []),
48-
kwargs=schedule.get("kwargs", {}),
49-
cron=schedule.get("cron"),
50-
time=schedule.get("time"),
51-
cron_offset=schedule.get("cron_offset"),
52-
),
50+
schedule_id = uuid.uuid4().hex
51+
52+
self.schedules[schedule_id] = ScheduledTask(
53+
task_name=task_name,
54+
labels=labels,
55+
schedule_id=schedule_id,
56+
args=schedule.get("args", []),
57+
kwargs=schedule.get("kwargs", {}),
58+
cron=schedule.get("cron"),
59+
time=schedule.get("time"),
60+
cron_offset=schedule.get("cron_offset"),
5361
)
54-
return schedules
5562

56-
def post_send(self, scheduled_task: ScheduledTask) -> None:
63+
return await super().startup()
64+
65+
async def get_schedules(self) -> List["ScheduledTask"]:
66+
"""
67+
Return the schedules collected at startup.
68+
69+
This function does not inspect task labels itself;
70+
it returns what `startup` stored in `self.schedules`.
71+
72+
Schedules removed by `post_send`
73+
are no longer returned.
74+
75+
:return: list of schedules.
76+
"""
77+
return list(self.schedules.values())
78+
79+
def post_send(self, task: "ScheduledTask") -> None:
5780
"""
5881
Remove `time` schedule from task's scheduler list.
5982
@@ -62,22 +85,7 @@ def post_send(self, scheduled_task: ScheduledTask) -> None:
6285
6386
:param task: task that has just been sent
6487
"""
65-
if scheduled_task.cron or not scheduled_task.time:
88+
if task.cron or not task.time:
6689
return # it's scheduled task with cron label, do not remove this trigger.
6790

68-
for task_name, task in self.broker.get_all_tasks().items():
69-
if task.broker != self.broker:
70-
# if task broker doesn't match self, something is probably wrong
71-
logger.warning(
72-
f"Broker for {task_name} `{task.broker}` doesn't "
73-
f"match scheduler's broker `{self.broker}`",
74-
)
75-
continue
76-
if scheduled_task.task_name != task_name:
77-
continue
78-
79-
schedule_list = task.labels.get("schedule", []).copy()
80-
for idx, schedule in enumerate(schedule_list):
81-
if schedule.get("time") == scheduled_task.time:
82-
task.labels.get("schedule", []).pop(idx)
83-
return
91+
self.schedules.pop(task.schedule_id, None)

tests/cli/scheduler/test_task_delays.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
def test_should_run_success() -> None:
12-
hour = datetime.datetime.utcnow().hour
12+
hour = datetime.datetime.now(datetime.timezone.utc).hour
1313
delay = get_task_delay(
1414
ScheduledTask(
1515
task_name="",
@@ -97,18 +97,26 @@ def test_time_utc_with_local_zone() -> None:
9797
assert delay is not None and delay >= 0
9898

9999

100+
@freeze_time("2023-01-14 12:00:00")
100101
def test_time_localtime_without_zone() -> None:
101102
time = datetime.datetime.now(tz=pytz.FixedOffset(240)).replace(tzinfo=None)
103+
time_to_run = time - datetime.timedelta(seconds=1)
104+
102105
delay = get_task_delay(
103106
ScheduledTask(
104107
task_name="",
105108
labels={},
106109
args=[],
107110
kwargs={},
108-
time=time - datetime.timedelta(seconds=1),
111+
time=time_to_run,
109112
),
110113
)
111-
assert delay is None
114+
115+
expected_delay = time_to_run.replace(tzinfo=pytz.UTC) - datetime.datetime.now(
116+
pytz.UTC,
117+
)
118+
119+
assert delay == int(expected_delay.total_seconds())
112120

113121

114122
@freeze_time("2023-01-14 12:00:00")

tests/schedule_sources/test_label_based.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, List
33

44
import pytest
5+
import pytz
56

67
from taskiq.brokers.inmemory_broker import InMemoryBroker
78
from taskiq.schedule_sources.label_based import LabelScheduleSource
@@ -13,7 +14,7 @@
1314
"schedule_label",
1415
[
1516
pytest.param([{"cron": "* * * * *"}], id="cron"),
16-
pytest.param([{"time": datetime.utcnow()}], id="time"),
17+
pytest.param([{"time": datetime.now(pytz.UTC)}], id="time"),
1718
],
1819
)
1920
async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
@@ -27,6 +28,7 @@ def task() -> None:
2728
pass
2829

2930
source = LabelScheduleSource(broker)
31+
await source.startup()
3032
schedules = await source.get_schedules()
3133
assert schedules == [
3234
ScheduledTask(
@@ -53,5 +55,6 @@ def task() -> None:
5355
pass
5456

5557
source = LabelScheduleSource(broker)
58+
await source.startup()
5659
schedules = await source.get_schedules()
5760
assert schedules == []

tests/scheduler/test_label_based_sched.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Dict, List
44

55
import pytest
6+
import pytz
67
from freezegun import freeze_time
78

89
from taskiq.brokers.inmemory_broker import InMemoryBroker
@@ -18,7 +19,7 @@
1819
"schedule_label",
1920
[
2021
pytest.param([{"cron": "* * * * *"}], id="cron"),
21-
pytest.param([{"time": datetime.utcnow()}], id="time"),
22+
pytest.param([{"time": datetime.now(pytz.UTC)}], id="time"),
2223
],
2324
)
2425
async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
@@ -31,7 +32,9 @@ async def test_label_discovery(schedule_label: List[Dict[str, Any]]) -> None:
3132
def task() -> None:
3233
pass
3334

34-
schedules = await LabelScheduleSource(broker).get_schedules()
35+
source = LabelScheduleSource(broker)
36+
await source.startup()
37+
schedules = await source.get_schedules()
3538
assert schedules == [
3639
ScheduledTask(
3740
schedule_id=schedules[0].schedule_id,
@@ -57,6 +60,7 @@ def task() -> None:
5760
pass
5861

5962
source = LabelScheduleSource(broker)
63+
await source.startup()
6064
schedules = await source.get_schedules()
6165
assert schedules == []
6266

@@ -69,6 +73,8 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None:
6973
broker=broker,
7074
sources=[LabelScheduleSource(broker)],
7175
)
76+
for source in scheduler.sources:
77+
await source.startup()
7278

7379
# NOTE:
7480
# freeze time to 00:00, so task won't be scheduled by `cron`, only by `time`
@@ -77,8 +83,8 @@ async def test_task_scheduled_at_time_runs_only_once(mock_sleep: None) -> None:
7783
@broker.task(
7884
task_name="test_task",
7985
schedule=[
80-
{"time": datetime.utcnow(), "args": [1]},
81-
{"time": datetime.utcnow() + timedelta(days=1), "args": [2]},
86+
{"time": datetime.now(pytz.UTC), "args": [1]},
87+
{"time": datetime.now(pytz.UTC) + timedelta(days=1), "args": [2]},
8288
{"cron": "1 * * * *", "args": [3]},
8389
],
8490
)

0 commit comments

Comments
 (0)