Commit 8bc1b7f

Add batch documentation to readme
1 parent abfbe4d commit 8bc1b7f

2 files changed: +228 -0 lines changed

README.md

Lines changed: 144 additions & 0 deletions

@@ -186,6 +186,150 @@ with MPILock(MPI.COMM_WORLD) as mpilock:
    mpilock.unlock()
```

## MPIBatch Class

This class is useful for a common pattern where a pool of "workers", each
consisting of multiple processes, needs to complete "tasks" of varying size. The
fact that each worker is a group of processes means we cannot use the standard
MPIPoolExecutor from mpi4py. This scenario also makes it difficult to recover
from errors like a segmentation fault (since there is no "nanny" process).
Still, there are many applications where groups of processes need to work on a
fixed number of tasks and dynamically assign those tasks to a smaller number of
workers.
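
For orientation, the sketch below shows the general shape of such a worker pool
in plain mpi4py: the world communicator split into fixed-size groups, each with
its own sub-communicator. This is only an illustration of the pattern (with an
assumed group size), not how `MPIBatch` builds its worker communicators
internally.

```python
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Assumed group size, for illustration only.
worker_size = 2

# Which worker group this process belongs to, and a sub-communicator
# shared by the members of that group.
worker = comm.rank // worker_size
worker_comm = comm.Split(worker, comm.rank)

print(f"rank {comm.rank} -> worker {worker}, group rank {worker_comm.rank}")
```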

By default, the state of tasks is tracked purely in MPI shared memory. The
calling code is responsible for initializing the state of each task from
external information. There is also support for simple use of the filesystem
for state tracking in addition to the in-memory copy. This assumes a top-level
directory with a subdirectory for each task. A special "state" file is created
in each task directory, and the `MPIBatch` class ensures that only one process
at a time modifies that state file and the in-memory copy. Using the filesystem
can also help when running multiple batch instances that are working on the
same set of tasks.
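
When no filesystem directory is given, everything lives in shared memory and
the calling code seeds the initial state itself. Here is a minimal sketch of
that mode; it assumes the same constructor arguments as the full example below,
minus the filesystem options, and a hypothetical list `already_done` of task
indices known from external information to be complete.

```python
from mpi4py import MPI
from pshmem import MPIBatch

comm = MPI.COMM_WORLD
ntask = 10
worker_size = max(1, comm.size // 2)

# No task_fs_root / task_fs_names, so state is tracked only in shared memory.
batch = MPIBatch(comm, worker_size, ntask)

# Hypothetical external information: task indices already known to be complete.
already_done = [0, 3]
if comm.rank == 0:
    # One process seeds the shared state, as in the example below.
    for itask in already_done:
        batch.set_task_state(itask, batch.DONE)
comm.barrier()
```
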
### Example

Here is an example using `MPIBatch` to track the state of tasks using the
filesystem (not just in memory). For that use case, each task must have a
"name", which is used as its subdirectory. Note that if you run this script a
second time, you should first remove the output directory; otherwise nothing
will happen, since all tasks are already done.

```python
import random
import time
import numpy as np
from mpi4py import MPI

from pshmem import MPIBatch

comm = MPI.COMM_WORLD

def fake_task_work(wrk_comm):
    """A function which emulates the work for a single task."""
    # All processes in the worker group do something.
    slp = 0.2 + 0.2 * random.random()
    time.sleep(slp)
    # Wait until everyone in the group is done.
    if wrk_comm is not None:
        wrk_comm.barrier()

ntask = 10

# The top-level directory
task_dir = "test"

# The "names" (subdirectories) of each task
task_names = [f"task_{x:03d}" for x in range(ntask)]

# Two workers (or one, if running on a single process)
worker_size = 1
if comm.size > 1:
    worker_size = comm.size // 2

# Create the batch system to track the state of tasks.
batch = MPIBatch(
    comm,
    worker_size,
    ntask,
    task_fs_root=task_dir,
    task_fs_names=task_names,
)

# Track the tasks executed by each worker so that we can
# display that at the end. This variable is only for
# purposes of printing.
proc_tasks = batch.INVALID * np.ones(ntask, dtype=np.int32)

# Workers loop over tasks until there are none left.
task = -1
while task is not None:
    task = batch.next_task()
    if task is None:
        # Nothing left for this worker
        break
    try:
        proc_tasks[task] = batch.RUNNING
        fake_task_work(batch.worker_comm)
        if batch.worker_rank == 0:
            # Only one process in the worker group needs
            # to update the state.
            batch.set_task_state(task, batch.DONE)
        proc_tasks[task] = batch.DONE
    except Exception:
        # The task raised an exception; mark this task
        # as failed.
        if batch.worker_rank == 0:
            # Only one process in the worker group needs
            # to update the state.
            batch.set_task_state(task, batch.FAILED)
        proc_tasks[task] = batch.FAILED

# Wait for all workers to finish
comm.barrier()

# Each worker reports on its status
for iwork in range(batch.n_worker):
    if iwork == batch.worker:
        if batch.worker_rank == 0:
            proc_stat = [MPIBatch.state_to_string(x) for x in proc_tasks]
            msg = f"Worker {batch.worker} tasks = {proc_stat}"
            print(msg, flush=True)
    batch.comm.barrier()

# Cleanup
del batch
```

Putting this code into a script called `test_batch.py` and running it
produces:
```
mpirun -np 4 python3 test_batch.py

Worker 0 tasks = ['DONE', 'INVALID', 'INVALID', 'DONE', 'INVALID', 'DONE', 'INVALID', 'DONE', 'DONE', 'INVALID']
Worker 1 tasks = ['INVALID', 'DONE', 'DONE', 'INVALID', 'DONE', 'INVALID', 'DONE', 'INVALID', 'INVALID', 'DONE']
```

So you can see that tasks are assigned to different worker groups as those
workers complete previous tasks. The state is tracked on the filesystem with a
`state` file in each task directory. After running the script above we can look
at the contents of those files:
```
cat test/*/state
DONE
DONE
DONE
DONE
DONE
DONE
DONE
DONE
DONE
DONE
```
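
If you want the same information programmatically (for example, to find any
tasks marked FAILED before a re-run), a small Python equivalent of the `cat`
command above might look like the sketch below. It assumes the `test` output
directory used by the example script.

```python
import glob
import os

# Read back the per-task state files written during the run.
for state_file in sorted(glob.glob(os.path.join("test", "*", "state"))):
    task_name = os.path.basename(os.path.dirname(state_file))
    with open(state_file) as f:
        print(task_name, f.read().strip())
```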

## Tests

After installation, you can run some tests with:

test_scripts/readme_test_batch.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import random
import time
import numpy as np
from mpi4py import MPI

from pshmem import MPIBatch

comm = MPI.COMM_WORLD

def fake_task_work(wrk_comm):
    """A function which emulates the work for a single task."""
    # All processes in the worker group do something.
    slp = 0.2 + 0.2 * random.random()
    time.sleep(slp)
    # Wait until everyone in the group is done.
    if wrk_comm is not None:
        wrk_comm.barrier()

ntask = 10

# The top-level directory
task_dir = "test"

# The "names" (subdirectories) of each task
task_names = [f"task_{x:03d}" for x in range(ntask)]

# Two workers (or one, if running on a single process)
worker_size = 1
if comm.size > 1:
    worker_size = comm.size // 2

# Create the batch system to track the state of tasks.
batch = MPIBatch(
    comm,
    worker_size,
    ntask,
    task_fs_root=task_dir,
    task_fs_names=task_names,
)

# Track the tasks executed by each worker so that we can
# display that at the end. This variable is only for
# purposes of printing.
proc_tasks = batch.INVALID * np.ones(ntask, dtype=np.int32)

# Workers loop over tasks until there are none left.
task = -1
while task is not None:
    task = batch.next_task()
    if task is None:
        # Nothing left for this worker
        break
    try:
        proc_tasks[task] = batch.RUNNING
        fake_task_work(batch.worker_comm)
        if batch.worker_rank == 0:
            # Only one process in the worker group needs
            # to update the state.
            batch.set_task_state(task, batch.DONE)
        proc_tasks[task] = batch.DONE
    except Exception:
        # The task raised an exception; mark this task
        # as failed.
        if batch.worker_rank == 0:
            # Only one process in the worker group needs
            # to update the state.
            batch.set_task_state(task, batch.FAILED)
        proc_tasks[task] = batch.FAILED

# Wait for all workers to finish
comm.barrier()

# Each worker reports on its status
for iwork in range(batch.n_worker):
    if iwork == batch.worker:
        if batch.worker_rank == 0:
            proc_stat = [MPIBatch.state_to_string(x) for x in proc_tasks]
            msg = f"Worker {batch.worker} tasks = {proc_stat}"
            print(msg, flush=True)
    batch.comm.barrier()

# Cleanup
del batch
