|
5 | 5 | import copy |
6 | 6 | import logging |
7 | 7 | import os |
8 | | -from concurrent.futures import ( |
9 | | - FIRST_COMPLETED, |
10 | | - ProcessPoolExecutor, |
11 | | - wait, |
12 | | -) |
13 | 8 | from dataclasses import dataclass |
14 | 9 | from typing import Callable, Dict, Optional, Union |
15 | 10 |
|
@@ -49,20 +44,16 @@ def _get_loader(self): |
49 | 44 | loader = "taskgraph.loader.default:loader" |
50 | 45 | return find_object(loader) |
51 | 46 |
|
52 | | - def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts): |
53 | | - logger.debug(f"Loading tasks for kind {self.name}") |
54 | | - |
55 | | - parameters = Parameters(**parameters) |
| 47 | + def load_tasks(self, parameters, loaded_tasks, write_artifacts): |
56 | 48 | loader = self._get_loader() |
57 | 49 | config = copy.deepcopy(self.config) |
58 | 50 |
|
59 | | - inputs = loader( |
60 | | - self.name, |
61 | | - self.path, |
62 | | - config, |
63 | | - parameters, |
64 | | - list(kind_dependencies_tasks.values()), |
65 | | - ) |
| 51 | + kind_dependencies = config.get("kind-dependencies", []) |
| 52 | + kind_dependencies_tasks = { |
| 53 | + task.label: task for task in loaded_tasks if task.kind in kind_dependencies |
| 54 | + } |
| 55 | + |
| 56 | + inputs = loader(self.name, self.path, config, parameters, loaded_tasks) |
66 | 57 |
|
67 | 58 | transforms = TransformSequence() |
68 | 59 | for xform_path in config["transforms"]: |
@@ -96,7 +87,6 @@ def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts): |
96 | 87 | ) |
97 | 88 | for task_dict in transforms(trans_config, inputs) |
98 | 89 | ] |
99 | | - logger.info(f"Generated {len(tasks)} tasks for kind {self.name}") |
100 | 90 | return tasks |
101 | 91 |
|
102 | 92 | @classmethod |
@@ -261,69 +251,6 @@ def _load_kinds(self, graph_config, target_kinds=None): |
261 | 251 | except KindNotFound: |
262 | 252 | continue |
263 | 253 |
|
264 | | - def _load_tasks(self, kinds, kind_graph, parameters): |
265 | | - all_tasks = {} |
266 | | - futures_to_kind = {} |
267 | | - futures = set() |
268 | | - edges = set(kind_graph.edges) |
269 | | - |
270 | | - with ProcessPoolExecutor() as executor: |
271 | | - |
272 | | - def submit_ready_kinds(): |
273 | | - """Create the next batch of tasks for kinds without dependencies.""" |
274 | | - nonlocal kinds, edges, futures |
275 | | - loaded_tasks = all_tasks.copy() |
276 | | - kinds_with_deps = {edge[0] for edge in edges} |
277 | | - ready_kinds = ( |
278 | | - set(kinds) - kinds_with_deps - set(futures_to_kind.values()) |
279 | | - ) |
280 | | - for name in ready_kinds: |
281 | | - kind = kinds.get(name) |
282 | | - if not kind: |
283 | | - message = ( |
284 | | - f'Could not find the kind "{name}"\nAvailable kinds:\n' |
285 | | - ) |
286 | | - for k in sorted(kinds): |
287 | | - message += f' - "{k}"\n' |
288 | | - raise Exception(message) |
289 | | - |
290 | | - future = executor.submit( |
291 | | - kind.load_tasks, |
292 | | - dict(parameters), |
293 | | - { |
294 | | - k: t |
295 | | - for k, t in loaded_tasks.items() |
296 | | - if t.kind in kind.config.get("kind-dependencies", []) |
297 | | - }, |
298 | | - self._write_artifacts, |
299 | | - ) |
300 | | - futures.add(future) |
301 | | - futures_to_kind[future] = name |
302 | | - |
303 | | - submit_ready_kinds() |
304 | | - while futures: |
305 | | - done, _ = wait(futures, return_when=FIRST_COMPLETED) |
306 | | - for future in done: |
307 | | - if exc := future.exception(): |
308 | | - executor.shutdown(wait=False, cancel_futures=True) |
309 | | - raise exc |
310 | | - kind = futures_to_kind.pop(future) |
311 | | - futures.remove(future) |
312 | | - |
313 | | - for task in future.result(): |
314 | | - if task.label in all_tasks: |
315 | | - raise Exception("duplicate tasks with label " + task.label) |
316 | | - all_tasks[task.label] = task |
317 | | - |
318 | | - # Update state for next batch of futures. |
319 | | - del kinds[kind] |
320 | | - edges = {e for e in edges if e[1] != kind} |
321 | | - |
322 | | - # Submit any newly unblocked kinds |
323 | | - submit_ready_kinds() |
324 | | - |
325 | | - return all_tasks |
326 | | - |
327 | 254 | def _run(self): |
328 | 255 | logger.info("Loading graph configuration.") |
329 | 256 | graph_config = load_graph_config(self.root_dir) |
@@ -378,8 +305,31 @@ def _run(self): |
378 | 305 | ) |
379 | 306 |
|
380 | 307 | logger.info("Generating full task set") |
381 | | - all_tasks = self._load_tasks(kinds, kind_graph, parameters) |
| 308 | + all_tasks = {} |
| 309 | + for kind_name in kind_graph.visit_postorder(): |
| 310 | + logger.debug(f"Loading tasks for kind {kind_name}") |
| 311 | + |
| 312 | + kind = kinds.get(kind_name) |
| 313 | + if not kind: |
| 314 | + message = f'Could not find the kind "{kind_name}"\nAvailable kinds:\n' |
| 315 | + for k in sorted(kinds): |
| 316 | + message += f' - "{k}"\n' |
| 317 | + raise Exception(message) |
382 | 318 |
|
| 319 | + try: |
| 320 | + new_tasks = kind.load_tasks( |
| 321 | + parameters, |
| 322 | + list(all_tasks.values()), |
| 323 | + self._write_artifacts, |
| 324 | + ) |
| 325 | + except Exception: |
| 326 | + logger.exception(f"Error loading tasks for kind {kind_name}:") |
| 327 | + raise |
| 328 | + for task in new_tasks: |
| 329 | + if task.label in all_tasks: |
| 330 | + raise Exception("duplicate tasks with label " + task.label) |
| 331 | + all_tasks[task.label] = task |
| 332 | + logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}") |
383 | 333 | full_task_set = TaskGraph(all_tasks, Graph(frozenset(all_tasks), frozenset())) |
384 | 334 | yield self.verify("full_task_set", full_task_set, graph_config, parameters) |
385 | 335 |
|
|
0 commit comments