2
2
import sys
3
3
from multiprocessing import Pool , cpu_count , Manager
4
4
import time
5
- from typing import Callable , List , Any
6
- from threading import Thread , Event , Lock
5
+ from typing import Callable , List , Any , Union
6
+ from threading import Lock , Thread , Event
7
7
import shutil
8
8
9
+ from .runner_arguments import RunnerArguments , AdditionalSwiftSourcesArguments
10
+
11
+
9
12
class MonitoredFunction :
10
- def __init__ (self , fn : Callable , running_tasks : ListProxy , updated_repos : ValueProxy , lock : Lock ):
13
+ def __init__ (
14
+ self ,
15
+ fn : Callable ,
16
+ running_tasks : ListProxy ,
17
+ updated_repos : ValueProxy ,
18
+ lock : Lock
19
+ ):
11
20
self .fn = fn
12
21
self .running_tasks = running_tasks
13
22
self .updated_repos = updated_repos
14
23
self ._lock = lock
15
24
16
- def __call__ (self , * args ):
17
- task_name = args [0 ][ 2 ]
25
+ def __call__ (self , * args : Union [ RunnerArguments , AdditionalSwiftSourcesArguments ] ):
26
+ task_name = args [0 ]. repo_name
18
27
self .running_tasks .append (task_name )
28
+ result = None
19
29
try :
20
- return self .fn (* args )
30
+ result = self .fn (* args )
31
+ except Exception as e :
32
+ print (e )
21
33
finally :
22
34
self ._lock .acquire ()
23
- self .running_tasks .remove (task_name )
35
+ if task_name in self .running_tasks :
36
+ self .running_tasks .remove (task_name )
24
37
self .updated_repos .set (self .updated_repos .get () + 1 )
25
38
self ._lock .release ()
39
+ return result
26
40
27
41
28
42
class ParallelRunner :
29
- def __init__ (self , fn : Callable , pool_args : List [List [Any ]], n_processes : int = 0 ):
43
+ def __init__ (
44
+ self ,
45
+ fn : Callable ,
46
+ pool_args : List [Union [RunnerArguments , AdditionalSwiftSourcesArguments ]],
47
+ n_processes : int = 0 ,
48
+ ):
30
49
self ._monitor_polling_period = 0.1
31
50
if n_processes == 0 :
32
51
n_processes = cpu_count () * 2
33
52
self ._terminal_width = shutil .get_terminal_size ().columns
34
53
self ._n_processes = n_processes
35
54
self ._pool_args = pool_args
55
+ manager = Manager ()
56
+ self ._lock = manager .Lock ()
57
+ self ._running_tasks = manager .list ()
58
+ self ._updated_repos = manager .Value ("i" , 0 )
36
59
self ._fn = fn
37
- self ._lock = Manager ().Lock ()
38
- self ._pool = Pool (
39
- processes = self ._n_processes , initializer = self ._child_init , initargs = (self ._lock ,)
40
- )
41
- self ._verbose = pool_args [0 ][len (pool_args [0 ]) - 1 ]
60
+ self ._pool = Pool (processes = self ._n_processes )
61
+ self ._verbose = pool_args [0 ].verbose
62
+ self ._output_prefix = pool_args [0 ].output_prefix
42
63
self ._nb_repos = len (pool_args )
43
64
self ._stop_event = Event ()
44
- self ._running_tasks = Manager (). list ()
45
- self ._updated_repos = Manager (). Value ( 'i' , 0 )
46
- self . _monitored_fn = MonitoredFunction ( self . _fn , self . _running_tasks , self . _updated_repos , self . _lock )
65
+ self ._monitored_fn = MonitoredFunction (
66
+ self ._fn , self . _running_tasks , self . _updated_repos , self . _lock
67
+ )
47
68
48
69
def run (self ) -> List [Any ]:
49
70
print (
50
71
"Running ``%s`` with up to %d processes."
51
72
% (self ._fn .__name__ , self ._n_processes )
52
73
)
53
-
54
74
if self ._verbose :
55
75
results = self ._pool .map_async (
56
76
func = self ._fn , iterable = self ._pool_args
57
- ).get ()
77
+ ).get (timeout = 1800 )
58
78
self ._pool .close ()
59
79
self ._pool .join ()
60
80
else :
61
81
monitor_thread = Thread (target = self ._monitor , daemon = True )
62
82
monitor_thread .start ()
63
83
results = self ._pool .map_async (
64
84
func = self ._monitored_fn , iterable = self ._pool_args
65
- ).get ()
85
+ ).get (timeout = 1800 )
66
86
self ._pool .close ()
67
87
self ._pool .join ()
68
88
self ._stop_event .set ()
@@ -72,11 +92,14 @@ def run(self) -> List[Any]:
72
92
def _monitor (self ):
73
93
last_output = ""
74
94
while not self ._stop_event .is_set ():
95
+ self ._lock .acquire ()
75
96
current = list (self ._running_tasks )
76
97
current_line = ", " .join (current )
98
+ updated_repos = self ._updated_repos .get ()
99
+ self ._lock .release ()
77
100
78
101
if current_line != last_output :
79
- truncated = f"Updating [ { self ._updated_repos . get () } /{ self ._nb_repos } ] ({ current_line } )"
102
+ truncated = ( f" { self ._output_prefix } [ { updated_repos } /{ self ._nb_repos } ] ({ current_line } )")
80
103
if len (truncated ) > self ._terminal_width :
81
104
ellipsis_marker = " ..."
82
105
truncated = (
@@ -89,17 +112,11 @@ def _monitor(self):
89
112
90
113
time .sleep (self ._monitor_polling_period )
91
114
92
- sys .stdout .write ("\r " + " " * len (last_output ) + "\r " )
115
+ sys .stdout .write ("\r " + " " * len (last_output ) + "\r \n " )
93
116
sys .stdout .flush ()
94
117
95
118
@staticmethod
96
- def _clear_lines (n ):
97
- for _ in range (n ):
98
- sys .stdout .write ("\x1b [1A" )
99
- sys .stdout .write ("\x1b [2K" )
100
-
101
- @staticmethod
102
- def check_results (results , op ):
119
+ def check_results (results , op ) -> int :
103
120
"""Function used to check the results of ParallelRunner.
104
121
105
122
NOTE: This function was originally located in the shell module of
@@ -123,7 +140,3 @@ def check_results(results, op):
123
140
print (r .stderr .decode ())
124
141
return fail_count
125
142
126
- @staticmethod
127
- def _child_init (lck ):
128
- global lock
129
- lock = lck
0 commit comments