Skip to content

Commit 023af7c

Browse files
authored
Run in parallel via CLI (#273)
- Switch from `tqdm` to `rich` for progress tracking
- Set up `pytest-xdist` to distribute tests over multiple CPU cores
- Reworked CLI to parallelize morphing over the specified number of workers when there is more than one job and more than one worker requested; passing `0` uses all available cores
- Updated docs and test suite for the new additions
- Set the matplotlib backend to Agg in actions to hopefully fix the tkinter issue on Windows runners
1 parent ace8539 commit 023af7c

File tree

11 files changed

+330
-85
lines changed

11 files changed

+330
-85
lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ jobs:
4646
os: [macos-latest, ubuntu-latest, windows-latest]
4747
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
4848

49+
env:
50+
MPLBACKEND: Agg # non-interactive backend for matplotlib
51+
4952
steps:
5053
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
5154

.github/workflows/generate-morphs.yml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ jobs:
6464
echo "Detected changes to dataset(s): $DATASET_ALL_CHANGED_FILES"
6565
DATASET_ARGS=$(python bin/ci.py $DATASET_ALL_CHANGED_FILES)
6666
echo "Generating morphs for the following datasets: $DATASET_ARGS"
67-
parallel -j0 data-morph \
67+
data-morph \
6868
--start-shape $DATASET_ARGS \
69-
--target-shape {} \
70-
::: bullseye heart rectangle star slant_up
69+
--target-shape bullseye heart rectangle star slant_up \
70+
--workers 0
7171
7272
# If shapes are added or modified in this PR
7373
- name: Generate morphs from new or changed shapes
@@ -78,20 +78,20 @@ jobs:
7878
echo "Detected changes to shape(s): $SHAPE_ALL_CHANGED_FILES"
7979
SHAPE_ARGS=$(python bin/ci.py $SHAPE_ALL_CHANGED_FILES)
8080
echo "Generating morphs for the following shapes: $SHAPE_ARGS"
81-
parallel -j0 data-morph \
81+
data-morph \
8282
--start-shape music \
83-
--target-shape {} \
84-
::: $SHAPE_ARGS
83+
--target-shape $SHAPE_ARGS \
84+
--workers 0
8585
8686
# For core code changes, we want to do a couple morphs to see if they still look ok
8787
# Only need to run if neither of the previous two morphs ran
8888
- name: Morph shapes with core code changes
8989
if: steps.changed-files-yaml.outputs.dataset_any_changed != 'true' && steps.changed-files-yaml.outputs.shape_any_changed != 'true'
9090
run: |
91-
parallel -j0 data-morph \
91+
data-morph \
9292
--start-shape music \
93-
--target-shape {} \
94-
::: bullseye heart rectangle star slant_up
93+
--target-shape bullseye heart star \
94+
--workers 0
9595
9696
- uses: actions/upload-artifact@65c4c4a1ddee5b72f698fdd19549f0f0fb45cf08 # v4.6.0
9797
with:

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ API
1313
data_morph.data
1414
data_morph.morpher
1515
data_morph.plotting
16+
data_morph.progress
1617
data_morph.shapes
1718

1819
----

docs/cli.rst

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@ Examples
1919
2020
$ data-morph --start-shape panda --target-shape star
2121
22-
2. Morph the panda shape into all available target shapes:
22+
2. Morph the panda shape into all available target shapes distributing the work
23+
to as many worker processes as possible:
2324

2425
.. code-block:: console
2526
26-
$ data-morph --start-shape panda --target-shape all
27+
$ data-morph --start-shape panda --target-shape all --workers 0
2728
2829
3. Morph the cat, dog, and panda shapes into the circle and slant_down shapes:
2930

docs/conf.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,13 @@
5959
# -- intersphinx -------------------------------------------------------------
6060

6161
intersphinx_mapping = {
62-
'matplotlib': ('https://matplotlib.org/stable/', None),
63-
'numpy': ('https://numpy.org/doc/stable/', None),
64-
'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None),
65-
'Pillow': ('https://pillow.readthedocs.io/en/stable/', None),
66-
'pytest': ('https://pytest.org/en/stable/', None),
67-
'python': ('https://docs.python.org/3/', None),
62+
'matplotlib': ('https://matplotlib.org/stable', None),
63+
'numpy': ('https://numpy.org/doc/stable', None),
64+
'pandas': ('https://pandas.pydata.org/pandas-docs/stable', None),
65+
'Pillow': ('https://pillow.readthedocs.io/en/stable', None),
66+
'pytest': ('https://pytest.org/en/stable', None),
67+
'python': ('https://docs.python.org/3', None),
68+
'rich': ('https://rich.readthedocs.io/en/stable', None),
6869
}
6970

7071

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,15 @@ dependencies = [
4848
"matplotlib>=3.3",
4949
"numpy>=1.20",
5050
"pandas>=1.2",
51-
"tqdm>=4.64.1",
51+
"rich>=13.9.4",
5252
]
5353
optional-dependencies.dev = [
5454
"pre-commit",
5555
"pytest",
5656
"pytest-cov",
5757
"pytest-mock",
5858
"pytest-randomly",
59+
"pytest-xdist",
5960
]
6061
optional-dependencies.docs = [
6162
"pydata-sphinx-theme>=0.15.3",
@@ -140,6 +141,7 @@ addopts = [
140141
"-ra",
141142
"-l",
142143
"-v",
144+
"-n=auto", # use as many workers as possible with pytest-xdist
143145
"--tb=short",
144146
"--import-mode=importlib",
145147
"--strict-markers",

src/data_morph/cli.py

Lines changed: 188 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,30 @@
33
from __future__ import annotations
44

55
import argparse
6-
import sys
6+
import itertools
7+
import multiprocessing
8+
from concurrent.futures import ProcessPoolExecutor
9+
from pathlib import Path
710
from typing import TYPE_CHECKING
811

912
from . import __version__
1013
from .data.loader import DataLoader
1114
from .morpher import DataMorpher
15+
from .progress import DataMorphProgress
1216
from .shapes.factory import ShapeFactory
1317

1418
if TYPE_CHECKING:
1519
from collections.abc import Sequence
1620

21+
from rich.progress import TaskID
22+
1723
ARG_DEFAULTS = {
1824
'output_dir': 'morphed_data',
1925
'decimals': 2,
2026
'min_shake': 0.3,
2127
'iterations': 100_000,
2228
'freeze': 0,
29+
'workers': 2,
2330
}
2431

2532

@@ -50,6 +57,16 @@ def generate_parser() -> argparse.ArgumentParser:
5057
parser.add_argument(
5158
'--version', action='version', version=f'%(prog)s {__version__}'
5259
)
60+
parser.add_argument(
61+
'-w',
62+
'--workers',
63+
type=int,
64+
default=ARG_DEFAULTS['workers'],
65+
help=(
66+
f'The number of workers. Default {ARG_DEFAULTS["workers"]}. '
67+
'Pass 0 to use as many as possible.'
68+
),
69+
)
5370

5471
shape_config_group = parser.add_argument_group(
5572
'Shape Configuration (required)',
@@ -226,33 +243,151 @@ def generate_parser() -> argparse.ArgumentParser:
226243
return parser
227244

228245

229-
def main(argv: Sequence[str] | None = None) -> None:
246+
def _morph(
247+
data: str,
248+
shape: str,
249+
args: argparse.Namespace,
250+
progress: multiprocessing.DictProxy,
251+
task_id: TaskID,
252+
) -> None:
230253
"""
231-
Run Data Morph as a script.
254+
Run the morphing algorithm.
232255
233256
Parameters
234257
----------
235-
argv : Sequence[str] | None, optional
236-
Makes it possible to pass in options without running on
237-
the command line.
258+
data : str
259+
The dataset to use. This can be the name of a built-in dataset or a path to a
260+
CSV file containing the data.
261+
shape : str
262+
The name of the target shape.
263+
args : argparse.Namespace
264+
Command line arguments.
265+
progress : multiprocessing.DictProxy
266+
The state of all task progresses.
267+
task_id : TaskID
268+
The task ID assigned by the progress tracker.
269+
270+
Notes
271+
-----
272+
This should only be used with :func:`._parallelize`.
238273
"""
274+
progress[task_id] = {'progress': 0, 'total': args.iterations}
239275

240-
args = generate_parser().parse_args(argv)
276+
dataset = DataLoader.load_dataset(data, scale=args.scale)
277+
shape = ShapeFactory(dataset).generate_shape(shape)
241278

242-
target_shapes = (
243-
ShapeFactory.AVAILABLE_SHAPES
244-
if args.target_shape == 'all' or 'all' in args.target_shape
245-
else set(args.target_shape).intersection(ShapeFactory.AVAILABLE_SHAPES)
279+
morpher = DataMorpher(
280+
decimals=args.decimals,
281+
output_dir=args.output_dir,
282+
write_data=args.write_data,
283+
seed=args.seed,
284+
keep_frames=args.keep_frames,
285+
forward_only_animation=args.forward_only,
286+
num_frames=100,
287+
in_notebook=False,
246288
)
247-
if not target_shapes:
248-
raise ValueError(
249-
'No valid target shapes were provided. Valid options are '
250-
f"""'{"', '".join(ShapeFactory.AVAILABLE_SHAPES)}'."""
251-
)
252289

290+
_ = morpher.morph(
291+
start_shape=dataset,
292+
target_shape=shape,
293+
iterations=args.iterations,
294+
min_shake=args.shake,
295+
ease_in=args.ease_in or args.ease,
296+
ease_out=args.ease_out or args.ease,
297+
freeze_for=args.freeze,
298+
progress=progress,
299+
task_id=task_id,
300+
)
301+
302+
303+
def _parallelize(
304+
total_jobs: int,
305+
workers: int,
306+
args: argparse.Namespace,
307+
target_shapes: Sequence[str],
308+
) -> None:
309+
"""
310+
Run morphing algorithm in parallel.
311+
312+
Parameters
313+
----------
314+
total_jobs : int
315+
The total number of morphing jobs that need to be run.
316+
workers : int
317+
The number of worker processes to use.
318+
args : argparse.Namespace
319+
The command line arguments.
320+
target_shapes : Sequence[str]
321+
The target shapes for morphing (datasets are in ``args.start_shape``).
322+
"""
323+
only_show_running = total_jobs > workers
324+
325+
with (
326+
DataMorphProgress() as progress_tracker,
327+
multiprocessing.Manager() as manager,
328+
):
329+
task_progress = manager.dict()
330+
overall_progress_task = progress_tracker.add_task('[green]Overall progress')
331+
332+
with ProcessPoolExecutor(max_workers=workers) as executor:
333+
futures = [
334+
executor.submit(
335+
_morph,
336+
dataset,
337+
shape,
338+
args,
339+
task_progress,
340+
progress_tracker.add_task(
341+
f'{Path(dataset).stem} to {shape}', visible=False, start=False
342+
),
343+
)
344+
for dataset, shape in itertools.product(args.start_shape, target_shapes)
345+
]
346+
347+
while True:
348+
finished_jobs = sum(future.done() for future in futures)
349+
progress_tracker.update(
350+
overall_progress_task,
351+
completed=sum(task['progress'] for task in task_progress.values()),
352+
total=total_jobs * args.iterations,
353+
)
354+
for task_id, update_data in task_progress.items():
355+
latest = update_data['progress']
356+
total = update_data['total']
357+
358+
if not latest:
359+
# hack to make the elapsed time accurate for ones that start later on
360+
# this is necessary because rich.progress.Progress is not pickleable
361+
progress_tracker.start_task(task_id)
362+
363+
progress_tracker.update(
364+
task_id,
365+
completed=latest,
366+
total=total,
367+
visible=latest < total
368+
if only_show_running
369+
else latest <= total,
370+
)
371+
if finished_jobs == total_jobs:
372+
break
373+
374+
for future in futures:
375+
future.result()
376+
377+
378+
def _serialize(args: argparse.Namespace, target_shapes: Sequence[str]) -> None:
379+
"""
380+
Run the morphing algorithm serially.
381+
382+
Parameters
383+
----------
384+
args : argparse.Namespace
385+
The command line arguments.
386+
target_shapes : Sequence[str]
387+
The target shapes for morphing (datasets are in ``args.start_shape``).
388+
"""
253389
for start_shape in args.start_shape:
254390
dataset = DataLoader.load_dataset(start_shape, scale=args.scale)
255-
print(f"Processing starter shape '{dataset.name}'", file=sys.stderr)
256391

257392
shape_factory = ShapeFactory(dataset)
258393
morpher = DataMorpher(
@@ -266,10 +401,7 @@ def main(argv: Sequence[str] | None = None) -> None:
266401
in_notebook=False,
267402
)
268403

269-
total_shapes = len(target_shapes)
270-
for i, target_shape in enumerate(target_shapes, start=1):
271-
if total_shapes > 1:
272-
print(f'Morphing shape {i} of {total_shapes}', file=sys.stderr)
404+
for target_shape in target_shapes:
273405
_ = morpher.morph(
274406
start_shape=dataset,
275407
target_shape=shape_factory.generate_shape(target_shape),
@@ -279,3 +411,38 @@ def main(argv: Sequence[str] | None = None) -> None:
279411
ease_out=args.ease_out or args.ease,
280412
freeze_for=args.freeze,
281413
)
414+
415+
416+
def main(argv: Sequence[str] | None = None) -> None:
417+
"""
418+
Run Data Morph as a script.
419+
420+
Parameters
421+
----------
422+
argv : Sequence[str] | None, optional
423+
Makes it possible to pass in options without running on
424+
the command line.
425+
"""
426+
427+
args = generate_parser().parse_args(argv)
428+
429+
target_shapes = (
430+
ShapeFactory.AVAILABLE_SHAPES
431+
if args.target_shape == 'all' or 'all' in args.target_shape
432+
else set(args.target_shape).intersection(ShapeFactory.AVAILABLE_SHAPES)
433+
)
434+
if not target_shapes:
435+
raise ValueError(
436+
'No valid target shapes were provided. Valid options are '
437+
f"""'{"', '".join(ShapeFactory.AVAILABLE_SHAPES)}'."""
438+
)
439+
440+
total_jobs = len(args.start_shape) * len(target_shapes)
441+
442+
max_workers = multiprocessing.cpu_count()
443+
workers = max_workers if not args.workers else min(args.workers, max_workers)
444+
445+
if total_jobs > 1 and workers > 1:
446+
_parallelize(total_jobs, workers, args, target_shapes)
447+
else:
448+
_serialize(args, target_shapes)

0 commit comments

Comments (0)