33from __future__ import annotations
44
55import argparse
6- import sys
6+ import itertools
7+ import multiprocessing
8+ from concurrent .futures import ProcessPoolExecutor
9+ from pathlib import Path
710from typing import TYPE_CHECKING
811
912from . import __version__
1013from .data .loader import DataLoader
1114from .morpher import DataMorpher
15+ from .progress import DataMorphProgress
1216from .shapes .factory import ShapeFactory
1317
1418if TYPE_CHECKING :
1519 from collections .abc import Sequence
1620
21+ from rich .progress import TaskID
22+
# Default values for the optional command line arguments; shared between the
# parser definition and its help text (e.g., the --workers help string).
ARG_DEFAULTS = {
    'output_dir': 'morphed_data',  # default output directory for morphed results
    'decimals': 2,  # decimal places preserved by the morpher
    'min_shake': 0.3,  # lower bound for the point "shake" during morphing
    'iterations': 100_000,  # morphing iterations per (dataset, shape) job
    'workers': 2,  # worker processes; 0 means "use as many as possible"
    'freeze': 0,  # freeze_for value passed to DataMorpher.morph()
}
2431
2532
@@ -50,6 +57,16 @@ def generate_parser() -> argparse.ArgumentParser:
5057 parser .add_argument (
5158 '--version' , action = 'version' , version = f'%(prog)s { __version__ } '
5259 )
60+ parser .add_argument (
61+ '-w' ,
62+ '--workers' ,
63+ type = int ,
64+ default = ARG_DEFAULTS ['workers' ],
65+ help = (
66+ f'The number of workers. Default { ARG_DEFAULTS ["workers" ]} . '
67+ 'Pass 0 to use as many as possible.'
68+ ),
69+ )
5370
5471 shape_config_group = parser .add_argument_group (
5572 'Shape Configuration (required)' ,
@@ -226,33 +243,151 @@ def generate_parser() -> argparse.ArgumentParser:
226243 return parser
227244
228245
def _morph(
    data: str,
    shape: str,
    args: argparse.Namespace,
    # NOTE(review): ``multiprocessing.DictProxy`` does not exist at runtime
    # (the proxy type lives in ``multiprocessing.managers``); this is only
    # harmless because of ``from __future__ import annotations`` — confirm.
    progress: multiprocessing.DictProxy,
    task_id: TaskID,
) -> None:
    """
    Run the morphing algorithm for a single (dataset, target shape) pair.

    Parameters
    ----------
    data : str
        The dataset to use. This can be the name of a built-in dataset or a
        path to a CSV file containing the data.
    shape : str
        The name of the target shape.
    args : argparse.Namespace
        Command line arguments.
    progress : multiprocessing.DictProxy
        The state of all task progresses, shared across worker processes.
    task_id : TaskID
        The task ID assigned by the progress tracker.

    Notes
    -----
    This should only be used with :func:`._parallelize`.
    """
    # Register this task up front so the parent process can compute an
    # accurate overall total before any iterations complete.
    progress[task_id] = {'progress': 0, 'total': args.iterations}

    dataset = DataLoader.load_dataset(data, scale=args.scale)

    # Bind the generated shape to its own name instead of rebinding the
    # ``shape`` parameter, which is annotated as ``str``.
    target_shape = ShapeFactory(dataset).generate_shape(shape)

    morpher = DataMorpher(
        decimals=args.decimals,
        output_dir=args.output_dir,
        write_data=args.write_data,
        seed=args.seed,
        keep_frames=args.keep_frames,
        forward_only_animation=args.forward_only,
        num_frames=100,
        in_notebook=False,
    )

    # The morphed data is written to disk by the morpher; the return value
    # is intentionally discarded here.
    _ = morpher.morph(
        start_shape=dataset,
        target_shape=target_shape,
        iterations=args.iterations,
        min_shake=args.shake,
        ease_in=args.ease_in or args.ease,
        ease_out=args.ease_out or args.ease,
        freeze_for=args.freeze,
        progress=progress,
        task_id=task_id,
    )
302+
def _parallelize(
    total_jobs: int,
    workers: int,
    args: argparse.Namespace,
    target_shapes: Sequence[str],
) -> None:
    """
    Run the morphing algorithm in parallel across worker processes.

    Parameters
    ----------
    total_jobs : int
        The total number of morphing jobs that need to be run.
    workers : int
        The number of worker processes to use.
    args : argparse.Namespace
        The command line arguments.
    target_shapes : Sequence[str]
        The target shapes for morphing (datasets are in ``args.start_shape``).
    """
    # When there are more jobs than workers, hide per-task bars once they
    # finish so the display only shows jobs that are actually running.
    only_show_running = total_jobs > workers

    with (
        DataMorphProgress() as progress_tracker,
        multiprocessing.Manager() as manager,
    ):
        # Manager-backed dict shared with the workers: each worker writes
        # {'progress': ..., 'total': ...} under its task ID.
        task_progress = manager.dict()
        overall_progress_task = progress_tracker.add_task('[green]Overall progress')

        with ProcessPoolExecutor(max_workers=workers) as executor:
            # One job per (dataset, target shape) combination; each gets its
            # own progress bar, created hidden and not yet started.
            futures = [
                executor.submit(
                    _morph,
                    dataset,
                    shape,
                    args,
                    task_progress,
                    progress_tracker.add_task(
                        f'{Path(dataset).stem} to {shape}', visible=False, start=False
                    ),
                )
                for dataset, shape in itertools.product(args.start_shape, target_shapes)
            ]

            # NOTE(review): this polling loop has no sleep, so the parent
            # process spins while workers run — confirm whether the rich
            # refresh throttling makes this acceptable.
            while True:
                finished_jobs = sum(future.done() for future in futures)
                progress_tracker.update(
                    overall_progress_task,
                    completed=sum(task['progress'] for task in task_progress.values()),
                    total=total_jobs * args.iterations,
                )
                for task_id, update_data in task_progress.items():
                    latest = update_data['progress']
                    total = update_data['total']

                    if not latest:
                        # hack to make the elapsed time accurate for ones that start later on
                        # this is necessary because rich.progress.Progress is not pickleable
                        # (so the workers cannot call start_task themselves)
                        progress_tracker.start_task(task_id)

                    progress_tracker.update(
                        task_id,
                        completed=latest,
                        total=total,
                        visible=latest < total
                        if only_show_running
                        else latest <= total,
                    )
                if finished_jobs == total_jobs:
                    break

        # Re-raise any exception that occurred inside a worker process.
        for future in futures:
            future.result()
376+
377+
378+ def _serialize (args : argparse .Namespace , target_shapes : Sequence [str ]) -> None :
379+ """
380+ Run the morphing algorithm serially.
381+
382+ Parameters
383+ ----------
384+ args : argparse.Namespace
385+ The command line arguments.
386+ target_shapes : Sequence[str]
387+ The target shapes for morphing (datasets are in ``args.start_shape``).
388+ """
253389 for start_shape in args .start_shape :
254390 dataset = DataLoader .load_dataset (start_shape , scale = args .scale )
255- print (f"Processing starter shape '{ dataset .name } '" , file = sys .stderr )
256391
257392 shape_factory = ShapeFactory (dataset )
258393 morpher = DataMorpher (
@@ -266,10 +401,7 @@ def main(argv: Sequence[str] | None = None) -> None:
266401 in_notebook = False ,
267402 )
268403
269- total_shapes = len (target_shapes )
270- for i , target_shape in enumerate (target_shapes , start = 1 ):
271- if total_shapes > 1 :
272- print (f'Morphing shape { i } of { total_shapes } ' , file = sys .stderr )
404+ for target_shape in target_shapes :
273405 _ = morpher .morph (
274406 start_shape = dataset ,
275407 target_shape = shape_factory .generate_shape (target_shape ),
@@ -279,3 +411,38 @@ def main(argv: Sequence[str] | None = None) -> None:
279411 ease_out = args .ease_out or args .ease ,
280412 freeze_for = args .freeze ,
281413 )
414+
415+
def main(argv: Sequence[str] | None = None) -> None:
    """
    Run Data Morph as a script.

    Parameters
    ----------
    argv : Sequence[str] | None, optional
        Makes it possible to pass in options without running on
        the command line.

    Raises
    ------
    ValueError
        If no valid target shapes are provided, or if ``--workers`` is
        negative.
    """
    args = generate_parser().parse_args(argv)

    target_shapes = (
        ShapeFactory.AVAILABLE_SHAPES
        if args.target_shape == 'all' or 'all' in args.target_shape
        else set(args.target_shape).intersection(ShapeFactory.AVAILABLE_SHAPES)
    )
    if not target_shapes:
        raise ValueError(
            'No valid target shapes were provided. Valid options are '
            f"""'{"', '".join(ShapeFactory.AVAILABLE_SHAPES)}'."""
        )

    # Fail fast on a negative worker count: it would otherwise slip past the
    # ``not args.workers`` check below and only surface later as an opaque
    # ValueError from ProcessPoolExecutor inside _parallelize().
    if args.workers < 0:
        raise ValueError('workers must be a non-negative integer.')

    # One job per (dataset, target shape) combination.
    total_jobs = len(args.start_shape) * len(target_shapes)

    # --workers 0 means "use every available CPU"; otherwise cap the
    # requested count at the number of CPUs.
    max_workers = multiprocessing.cpu_count()
    workers = max_workers if not args.workers else min(args.workers, max_workers)

    # Only pay the multiprocessing overhead when it can actually help
    # (multiple jobs AND more than one worker); otherwise run serially.
    if total_jobs > 1 and workers > 1:
        _parallelize(total_jobs, workers, args, target_shapes)
    else:
        _serialize(args, target_shapes)