Skip to content

Commit 5207917

Browse files
Alphareahal
authored and committed
feat: use new batched queries in task replacement
Batching the queries that go from task index to task status reduces the number of (sometimes trans-continental) queries from 2*900+ to ~2. This reduces the time spent replacing tasks by 20% to 75%, depending on the use case. The 20% improvement in wall time was observed when running `mach taskgraph morphed` on a CI worker, while the 75% improvement was observed on a developer machine in France running `mach taskgraph full`. More information in taskcluster/taskcluster-rfcs#189.
1 parent ac26282 commit 5207917

File tree

5 files changed

+153
-39
lines changed

5 files changed

+153
-39
lines changed

src/taskgraph/optimize/base.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from taskgraph.taskgraph import TaskGraph
2323
from taskgraph.util.parameterization import resolve_task_references, resolve_timestamps
2424
from taskgraph.util.python_path import import_sibling_modules
25+
from taskgraph.util.taskcluster import find_task_id_batched, status_task_batched
2526

2627
logger = logging.getLogger(__name__)
2728
registry = {}
@@ -51,6 +52,9 @@ def optimize_task_graph(
5152
Perform task optimization, returning a taskgraph and a map from label to
5253
assigned taskId, including replacement tasks.
5354
"""
55+
# avoid circular import
56+
from taskgraph.optimize.strategies import IndexSearch
57+
5458
label_to_taskid = {}
5559
if not existing_tasks:
5660
existing_tasks = {}
@@ -70,6 +74,23 @@ def optimize_task_graph(
7074
do_not_optimize=do_not_optimize,
7175
)
7276

77+
# Gather each relevant task's index
78+
indexes = set()
79+
for label in target_task_graph.graph.visit_postorder():
80+
if label in do_not_optimize:
81+
continue
82+
_, strategy, arg = optimizations(label)
83+
if isinstance(strategy, IndexSearch) and arg is not None:
84+
indexes.update(arg)
85+
86+
index_to_taskid = {}
87+
taskid_to_status = {}
88+
if indexes:
89+
# Find their respective status using TC index/queue batch APIs
90+
indexes = list(indexes)
91+
index_to_taskid = find_task_id_batched(indexes)
92+
taskid_to_status = status_task_batched(list(index_to_taskid.values()))
93+
7394
replaced_tasks = replace_tasks(
7495
target_task_graph=target_task_graph,
7596
optimizations=optimizations,
@@ -78,6 +99,8 @@ def optimize_task_graph(
7899
label_to_taskid=label_to_taskid,
79100
existing_tasks=existing_tasks,
80101
removed_tasks=removed_tasks,
102+
index_to_taskid=index_to_taskid,
103+
taskid_to_status=taskid_to_status,
81104
)
82105

83106
return (
@@ -259,12 +282,17 @@ def replace_tasks(
259282
label_to_taskid,
260283
removed_tasks,
261284
existing_tasks,
285+
index_to_taskid,
286+
taskid_to_status,
262287
):
263288
"""
264289
Implement the "Replacing Tasks" phase, returning a set of task labels of
265290
all replaced tasks. The replacement taskIds are added to label_to_taskid as
266291
a side-effect.
267292
"""
293+
# avoid circular import
294+
from taskgraph.optimize.strategies import IndexSearch
295+
268296
opt_counts = defaultdict(int)
269297
replaced = set()
270298
dependents_of = target_task_graph.graph.reverse_links_dict()
@@ -307,6 +335,10 @@ def replace_tasks(
307335
deadline = max(
308336
resolve_timestamps(now, task.task["deadline"]) for task in dependents
309337
)
338+
339+
if isinstance(opt, IndexSearch):
340+
arg = arg, index_to_taskid, taskid_to_status
341+
310342
repl = opt.should_replace_task(task, params, deadline, arg)
311343
if repl:
312344
if repl is True:

src/taskgraph/optimize/strategies.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
from taskgraph.optimize.base import OptimizationStrategy, register_strategy
55
from taskgraph.util.path import match as match_path
6-
from taskgraph.util.taskcluster import find_task_id, status_task
76

87
logger = logging.getLogger(__name__)
98

@@ -22,12 +21,14 @@ class IndexSearch(OptimizationStrategy):
2221

2322
fmt = "%Y-%m-%dT%H:%M:%S.%fZ"
2423

25-
def should_replace_task(self, task, params, deadline, index_paths):
24+
def should_replace_task(self, task, params, deadline, arg):
2625
"Look for a task with one of the given index paths"
26+
index_paths, label_to_taskid, taskid_to_status = arg
27+
2728
for index_path in index_paths:
2829
try:
29-
task_id = find_task_id(index_path)
30-
status = status_task(task_id)
30+
task_id = label_to_taskid[index_path]
31+
status = taskid_to_status[task_id]
3132
# status can be `None` if we're in `testing` mode
3233
# (e.g. test-action-callback)
3334
if not status or status.get("state") in ("exception", "failed"):
@@ -40,7 +41,7 @@ def should_replace_task(self, task, params, deadline, index_paths):
4041

4142
return task_id
4243
except KeyError:
43-
# 404 will end up here and go on to the next index path
44+
# go on to the next index path
4445
pass
4546

4647
return False

src/taskgraph/util/taskcluster.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,48 @@ def find_task_id(index_path, use_proxy=False):
193193
return response.json()["taskId"]
194194

195195

196+
def find_task_id_batched(index_paths, use_proxy=False):
    """Gets the task id of multiple tasks given their respective index.

    The index service is queried page by page; each page's
    ``continuationToken`` is sent back with the next request until the
    service stops returning one (or returns an empty page).

    Args:
        index_paths (List[str]): A list of task indexes.
        use_proxy (bool): Whether to use taskcluster-proxy (default: False)

    Returns:
        Dict[str, str]: A dictionary object mapping each valid index path
                        to its respective task id.

    Raises:
        ValueError: If the accumulated number of returned tasks exceeds the
            number of index paths that were requested.

    See the endpoint here:
        https://docs.taskcluster.net/docs/reference/core/index/api#findTasksAtIndex
    """
    endpoint = liburls.api(get_root_url(use_proxy), "index", "v1", "tasks/indexes")
    task_ids = {}
    # None on the first request; afterwards the token of the previous page.
    continuation_token = None

    while True:
        response = _do_request(
            endpoint,
            json={
                "indexes": index_paths,
            },
            params={"continuationToken": continuation_token},
        )

        response_data = response.json()
        response_tasks = response_data["tasks"]
        if not response_tasks:
            break
        if (len(task_ids) + len(response_tasks)) > len(index_paths):
            # Sanity check
            raise ValueError("more task ids were returned than were asked for")
        task_ids.update((t["namespace"], t["taskId"]) for t in response_tasks)

        # Must rebind the same variable that is sent in `params` above,
        # otherwise every iteration would re-request the first page.
        continuation_token = response_data.get("continuationToken")
        if continuation_token is None:
            break
    return task_ids
236+
237+
196238
def get_artifact_from_index(index_path, artifact_path, use_proxy=False):
197239
full_path = index_path + "/artifacts/" + artifact_path
198240
response = _do_request(get_index_url(full_path, use_proxy))
@@ -271,6 +313,49 @@ def status_task(task_id, use_proxy=False):
271313
return status
272314

273315

316+
def status_task_batched(task_ids, use_proxy=False):
    """Gets the status of multiple tasks given task_ids.

    In testing mode, just logs that it would have retrieved statuses and
    returns an empty mapping, so callers can still index the result by
    task id (a missing id then surfaces as ``KeyError``).

    The queue service is queried page by page; each page's
    ``continuationToken`` is sent back with the next request until the
    service stops returning one (or returns an empty page).

    Args:
        task_ids (List[str]): A list of task ids.
        use_proxy (bool): Whether to use taskcluster-proxy (default: False)

    Returns:
        Dict[str, dict]: A dictionary mapping each task id to its status
            object, as defined here:
            https://docs.taskcluster.net/docs/reference/platform/queue/api#statuses

    Raises:
        ValueError: If the accumulated number of returned statuses exceeds
            the number of task ids that were requested.
    """
    if testing:
        logger.info(f"Would have gotten status for {len(task_ids)} tasks.")
        return {}
    endpoint = liburls.api(get_root_url(use_proxy), "queue", "v1", "tasks/status")
    statuses = {}
    # None on the first request; afterwards the token of the previous page.
    continuation_token = None

    while True:
        response = _do_request(
            endpoint,
            json={
                "taskIds": task_ids,
            },
            params={
                "continuationToken": continuation_token,
            },
        )
        response_data = response.json()
        response_tasks = response_data["statuses"]
        if not response_tasks:
            break
        if (len(statuses) + len(response_tasks)) > len(task_ids):
            # Sanity check
            raise ValueError("more task statuses were returned than were asked for")
        statuses.update((t["taskId"], t["status"]) for t in response_tasks)

        # Must rebind the same variable that is sent in `params` above,
        # otherwise every iteration would re-request the first page.
        continuation_token = response_data.get("continuationToken")
        if continuation_token is None:
            break
    return statuses
357+
358+
274359
def state_task(task_id, use_proxy=False):
275360
"""Gets the state of a task given a task_id.
276361

test/test_optimize.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -269,19 +269,20 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
269269

270270

271271
@pytest.mark.parametrize(
272-
"graph,kwargs,exp_replaced,exp_removed,exp_label_to_taskid",
272+
"graph,kwargs,exp_replaced,exp_removed",
273273
(
274274
# A task cannot be replaced if it depends on one that was not replaced
275275
pytest.param(
276276
make_triangle(
277277
t1={"replace": "e1"},
278278
t3={"replace": "e3"},
279279
),
280-
{},
280+
{
281+
"index_to_taskid": {"t1": "e1"},
282+
},
281283
# expectations
282284
{"t1"},
283285
set(),
284-
{"t1": "e1"},
285286
id="blocked",
286287
),
287288
# A task cannot be replaced if it should not be optimized
@@ -291,11 +292,13 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
291292
t2={"replace": "xxx"}, # but do_not_optimize
292293
t3={"replace": "e3"},
293294
),
294-
{"do_not_optimize": {"t2"}},
295+
{
296+
"do_not_optimize": {"t2"},
297+
"index_to_taskid": {"t1": "e1"},
298+
},
295299
# expectations
296300
{"t1"},
297301
set(),
298-
{"t1": "e1"},
299302
id="do_not_optimize",
300303
),
301304
# No tasks are replaced when strategy is 'never'
@@ -305,7 +308,6 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
305308
# expectations
306309
set(),
307310
set(),
308-
{},
309311
id="never",
310312
),
311313
# All replaceable tasks are replaced when strategy is 'replace'
@@ -315,11 +317,12 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
315317
t2={"replace": "e2"},
316318
t3={"replace": "e3"},
317319
),
318-
{},
320+
{
321+
"index_to_taskid": {"t1": "e1", "t2": "e2", "t3": "e3"},
322+
},
319323
# expectations
320324
{"t1", "t2", "t3"},
321325
set(),
322-
{"t1": "e1", "t2": "e2", "t3": "e3"},
323326
id="all",
324327
),
325328
# A task can be replaced with nothing
@@ -329,11 +332,12 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
329332
t2={"replace": True},
330333
t3={"replace": True},
331334
),
332-
{},
335+
{
336+
"index_to_taskid": {"t1": "e1"},
337+
},
333338
# expectations
334339
{"t1"},
335340
{"t2", "t3"},
336-
{"t1": "e1"},
337341
id="tasks_removed",
338342
),
339343
# A task which expires before a dependents deadline is not a valid replacement.
@@ -353,7 +357,6 @@ def test_remove_tasks(monkeypatch, graph, kwargs, exp_removed):
353357
# expectations
354358
set(),
355359
set(),
356-
{},
357360
id="deadline",
358361
),
359362
),
@@ -363,7 +366,6 @@ def test_replace_tasks(
363366
kwargs,
364367
exp_replaced,
365368
exp_removed,
366-
exp_label_to_taskid,
367369
):
368370
"""Tests the `replace_tasks` function.
369371
@@ -378,6 +380,8 @@ def test_replace_tasks(
378380
kwargs.setdefault("params", {})
379381
kwargs.setdefault("do_not_optimize", set())
380382
kwargs.setdefault("label_to_taskid", {})
383+
kwargs.setdefault("index_to_taskid", {})
384+
kwargs.setdefault("taskid_to_status", {})
381385
kwargs.setdefault("removed_tasks", set())
382386
kwargs.setdefault("existing_tasks", {})
383387

@@ -388,7 +392,6 @@ def test_replace_tasks(
388392
)
389393
assert got_replaced == exp_replaced
390394
assert kwargs["removed_tasks"] == exp_removed
391-
assert kwargs["label_to_taskid"] == exp_label_to_taskid
392395

393396

394397
@pytest.mark.parametrize(

test/test_optimize_strategies.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Any copyright is dedicated to the public domain.
22
# http://creativecommons.org/publicdomain/zero/1.0/
33

4-
import os
54
from datetime import datetime
65
from test.fixtures.gen import make_task
76
from time import mktime
@@ -44,31 +43,25 @@ def params():
4443
),
4544
),
4645
)
47-
def test_index_search(responses, params, state, expires, expected):
46+
def test_index_search(state, expires, expected):
4847
taskid = "abc"
4948
index_path = "foo.bar.latest"
50-
responses.add(
51-
responses.GET,
52-
f"{os.environ['TASKCLUSTER_ROOT_URL']}/api/index/v1/task/{index_path}",
53-
json={"taskId": taskid},
54-
status=200,
55-
)
56-
57-
responses.add(
58-
responses.GET,
59-
f"{os.environ['TASKCLUSTER_ROOT_URL']}/api/queue/v1/task/{taskid}/status",
60-
json={
61-
"status": {
62-
"state": state,
63-
"expires": expires,
64-
}
65-
},
66-
status=200,
67-
)
49+
label_to_taskid = {index_path: taskid}
50+
taskid_to_status = {
51+
taskid: {
52+
"state": state,
53+
"expires": expires,
54+
}
55+
}
6856

6957
opt = IndexSearch()
7058
deadline = "2021-06-07T19:03:20.482Z"
71-
assert opt.should_replace_task({}, params, deadline, (index_path,)) == expected
59+
assert (
60+
opt.should_replace_task(
61+
{}, params, deadline, ((index_path,), label_to_taskid, taskid_to_status)
62+
)
63+
== expected
64+
)
7265

7366

7467
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)