Skip to content

Commit 7909a94

Browse files
[update-checkout] refactor arguments passing
1 parent a76b7e5 commit 7909a94

File tree

5 files changed

+174
-90
lines changed

5 files changed

+174
-90
lines changed

utils/update_checkout/tests/test_clone.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def test_clone_with_additional_scheme(self):
3636
'--config', self.additional_config_path,
3737
'--source-root', self.source_root,
3838
'--clone',
39-
'--scheme', 'extra'])
39+
'--scheme', 'extra',
40+
'--verbose'])
4041

4142
# Test that we're actually checking out the 'extra' scheme based on the output
4243
self.assertIn(b"git checkout refs/heads/main", output)

utils/update_checkout/tests/test_locked_repository.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,35 @@
11
import unittest
22
from unittest.mock import patch
33

4-
from update_checkout.update_checkout import _is_any_repository_locked
4+
from update_checkout.update_checkout import UpdateArguments, _is_any_repository_locked
5+
6+
7+
def _update_arguments_with_fake_path(repo_name: str, path: str) -> UpdateArguments:
8+
return UpdateArguments(
9+
repo_name=repo_name,
10+
source_root=path,
11+
config={},
12+
scheme_name="",
13+
scheme_map=None,
14+
tag="",
15+
timestamp=None,
16+
reset_to_remote=False,
17+
clean=False,
18+
stash=False,
19+
cross_repos_pr=False,
20+
output_prefix="",
21+
verbose=False,
22+
)
23+
524

625
class TestIsAnyRepositoryLocked(unittest.TestCase):
726
@patch("os.path.exists")
827
@patch("os.path.isdir")
928
@patch("os.listdir")
1029
def test_repository_with_lock_file(self, mock_listdir, mock_isdir, mock_exists):
1130
pool_args = [
12-
("/fake_path", None, "repo1"),
13-
("/fake_path", None, "repo2"),
31+
_update_arguments_with_fake_path("repo1", "/fake_path"),
32+
_update_arguments_with_fake_path("repo2", "/fake_path"),
1433
]
1534

1635
def listdir_side_effect(path):
@@ -32,7 +51,7 @@ def listdir_side_effect(path):
3251
@patch("os.listdir")
3352
def test_repository_without_git_dir(self, mock_listdir, mock_isdir, mock_exists):
3453
pool_args = [
35-
("/fake_path", None, "repo1"),
54+
_update_arguments_with_fake_path("repo1", "/fake_path"),
3655
]
3756

3857
mock_exists.return_value = False
@@ -47,7 +66,7 @@ def test_repository_without_git_dir(self, mock_listdir, mock_isdir, mock_exists)
4766
@patch("os.listdir")
4867
def test_repository_with_git_file(self, mock_listdir, mock_isdir, mock_exists):
4968
pool_args = [
50-
("/fake_path", None, "repo1"),
69+
_update_arguments_with_fake_path("repo1", "/fake_path"),
5170
]
5271

5372
mock_exists.return_value = True
@@ -60,9 +79,11 @@ def test_repository_with_git_file(self, mock_listdir, mock_isdir, mock_exists):
6079
@patch("os.path.exists")
6180
@patch("os.path.isdir")
6281
@patch("os.listdir")
63-
def test_repository_with_multiple_lock_files(self, mock_listdir, mock_isdir, mock_exists):
82+
def test_repository_with_multiple_lock_files(
83+
self, mock_listdir, mock_isdir, mock_exists
84+
):
6485
pool_args = [
65-
("/fake_path", None, "repo1"),
86+
_update_arguments_with_fake_path("repo1", "/fake_path"),
6687
]
6788

6889
mock_exists.return_value = True
@@ -77,7 +98,7 @@ def test_repository_with_multiple_lock_files(self, mock_listdir, mock_isdir, moc
7798
@patch("os.listdir")
7899
def test_repository_with_no_lock_files(self, mock_listdir, mock_isdir, mock_exists):
79100
pool_args = [
80-
("/fake_path", None, "repo1"),
101+
_update_arguments_with_fake_path("repo1", "/fake_path"),
81102
]
82103

83104
mock_exists.return_value = True
@@ -86,4 +107,3 @@ def test_repository_with_no_lock_files(self, mock_listdir, mock_isdir, mock_exis
86107

87108
result = _is_any_repository_locked(pool_args)
88109
self.assertEqual(result, set())
89-

utils/update_checkout/update_checkout/parallel_runner.py

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,67 +2,87 @@
22
import sys
33
from multiprocessing import Pool, cpu_count, Manager
44
import time
5-
from typing import Callable, List, Any
6-
from threading import Thread, Event, Lock
5+
from typing import Callable, List, Any, Union
6+
from threading import Lock, Thread, Event
77
import shutil
88

9+
from .runner_arguments import RunnerArguments, AdditionalSwiftSourcesArguments
10+
11+
912
class MonitoredFunction:
10-
def __init__(self, fn: Callable, running_tasks: ListProxy, updated_repos: ValueProxy, lock: Lock):
13+
def __init__(
14+
self,
15+
fn: Callable,
16+
running_tasks: ListProxy,
17+
updated_repos: ValueProxy,
18+
lock: Lock
19+
):
1120
self.fn = fn
1221
self.running_tasks = running_tasks
1322
self.updated_repos = updated_repos
1423
self._lock = lock
1524

16-
def __call__(self, *args):
17-
task_name = args[0][2]
25+
def __call__(self, *args: Union[RunnerArguments, AdditionalSwiftSourcesArguments]):
26+
task_name = args[0].repo_name
1827
self.running_tasks.append(task_name)
28+
result = None
1929
try:
20-
return self.fn(*args)
30+
result = self.fn(*args)
31+
except Exception as e:
32+
print(e)
2133
finally:
2234
self._lock.acquire()
23-
self.running_tasks.remove(task_name)
35+
if task_name in self.running_tasks:
36+
self.running_tasks.remove(task_name)
2437
self.updated_repos.set(self.updated_repos.get() + 1)
2538
self._lock.release()
39+
return result
2640

2741

2842
class ParallelRunner:
29-
def __init__(self, fn: Callable, pool_args: List[List[Any]], n_processes: int = 0):
43+
def __init__(
44+
self,
45+
fn: Callable,
46+
pool_args: List[Union[RunnerArguments, AdditionalSwiftSourcesArguments]],
47+
n_processes: int = 0,
48+
):
3049
self._monitor_polling_period = 0.1
3150
if n_processes == 0:
3251
n_processes = cpu_count() * 2
3352
self._terminal_width = shutil.get_terminal_size().columns
3453
self._n_processes = n_processes
3554
self._pool_args = pool_args
55+
manager = Manager()
56+
self._lock = manager.Lock()
57+
self._running_tasks = manager.list()
58+
self._updated_repos = manager.Value("i", 0)
3659
self._fn = fn
37-
self._lock = Manager().Lock()
38-
self._pool = Pool(
39-
processes=self._n_processes, initializer=self._child_init, initargs=(self._lock,)
40-
)
41-
self._verbose = pool_args[0][len(pool_args[0]) - 1]
60+
self._pool = Pool(processes=self._n_processes)
61+
self._verbose = pool_args[0].verbose
62+
self._output_prefix = pool_args[0].output_prefix
4263
self._nb_repos = len(pool_args)
4364
self._stop_event = Event()
44-
self._running_tasks = Manager().list()
45-
self._updated_repos = Manager().Value('i', 0)
46-
self._monitored_fn = MonitoredFunction(self._fn, self._running_tasks, self._updated_repos, self._lock)
65+
self._monitored_fn = MonitoredFunction(
66+
self._fn, self._running_tasks, self._updated_repos, self._lock
67+
)
4768

4869
def run(self) -> List[Any]:
4970
print(
5071
"Running ``%s`` with up to %d processes."
5172
% (self._fn.__name__, self._n_processes)
5273
)
53-
5474
if self._verbose:
5575
results = self._pool.map_async(
5676
func=self._fn, iterable=self._pool_args
57-
).get()
77+
).get(timeout=1800)
5878
self._pool.close()
5979
self._pool.join()
6080
else:
6181
monitor_thread = Thread(target=self._monitor, daemon=True)
6282
monitor_thread.start()
6383
results = self._pool.map_async(
6484
func=self._monitored_fn, iterable=self._pool_args
65-
).get()
85+
).get(timeout=1800)
6686
self._pool.close()
6787
self._pool.join()
6888
self._stop_event.set()
@@ -72,11 +92,14 @@ def run(self) -> List[Any]:
7292
def _monitor(self):
7393
last_output = ""
7494
while not self._stop_event.is_set():
95+
self._lock.acquire()
7596
current = list(self._running_tasks)
7697
current_line = ", ".join(current)
98+
updated_repos = self._updated_repos.get()
99+
self._lock.release()
77100

78101
if current_line != last_output:
79-
truncated = f"Updating [{self._updated_repos.get()}/{self._nb_repos}] ({current_line})"
102+
truncated = (f"{self._output_prefix} [{updated_repos}/{self._nb_repos}] ({current_line})")
80103
if len(truncated) > self._terminal_width:
81104
ellipsis_marker = " ..."
82105
truncated = (
@@ -89,17 +112,11 @@ def _monitor(self):
89112

90113
time.sleep(self._monitor_polling_period)
91114

92-
sys.stdout.write("\r" + " " * len(last_output) + "\r")
115+
sys.stdout.write("\r" + " " * len(last_output) + "\r\n")
93116
sys.stdout.flush()
94117

95118
@staticmethod
96-
def _clear_lines(n):
97-
for _ in range(n):
98-
sys.stdout.write("\x1b[1A")
99-
sys.stdout.write("\x1b[2K")
100-
101-
@staticmethod
102-
def check_results(results, op):
119+
def check_results(results, op) -> int:
103120
"""Function used to check the results of ParallelRunner.
104121
105122
NOTE: This function was originally located in the shell module of
@@ -123,7 +140,3 @@ def check_results(results, op):
123140
print(r.stderr.decode())
124141
return fail_count
125142

126-
@staticmethod
127-
def _child_init(lck):
128-
global lock
129-
lock = lck
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from dataclasses import dataclass
2+
from typing import Any, Dict
3+
4+
@dataclass
5+
class RunnerArguments:
6+
repo_name: str
7+
scheme_name: str
8+
output_prefix: str
9+
verbose: bool
10+
11+
@dataclass
12+
class UpdateArguments(RunnerArguments):
13+
source_root: str
14+
config: Dict[str, Any]
15+
scheme_map: Any
16+
tag: str
17+
timestamp: Any
18+
reset_to_remote: bool
19+
clean: bool
20+
stash: bool
21+
cross_repos_pr: bool
22+
23+
@dataclass
24+
class AdditionalSwiftSourcesArguments(RunnerArguments):
25+
args: RunnerArguments
26+
repo_info: str
27+
repo_branch: str
28+
remote: str
29+
with_ssh: bool
30+
skip_history: bool
31+
skip_tags: bool
32+
skip_repository_list: bool
33+
use_submodules: bool

0 commit comments

Comments
 (0)