MPI Tutorial

import math
from mpi4py import MPI

if __name__ == "__main__":
    lst = list(range(1, 11))

    # COMM_WORLD is the predefined communicator that contains all processes.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    # Round up so that every element is assigned to some chunk.
    chunk_size = math.ceil(len(lst) / size)

    # Split the list into chunks of at most chunk_size elements.
    chunks = [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
    print(chunks)

    # Each process sums its own chunk; a rank left without a chunk contributes 0.
    result = sum(chunks[rank]) if rank < len(chunks) else 0

    # Wait until every process has computed its partial sum.
    comm.Barrier()
    # Collect the partial sums into a list on the root process (rank 0).
    results = comm.gather(result, root=0)

    if rank == 0:
        final_result = sum(results)
        print(final_result)

This code parallelizes the sum of a list of numbers using MPI. Here's a breakdown of how it works:

  1. First, we initialize a list lst with the numbers 1 through 10.
  2. Next, we get a handle to the predefined communicator MPI.COMM_WORLD as comm, and query the rank of the current process and the total number of processes (size).
  3. We calculate the chunk size by dividing the length of lst by the number of processes and rounding up with math.ceil(), so that every element lands in some chunk and the number of chunks never exceeds the number of processes.
  4. We use a list comprehension to split lst into chunks of size chunk_size using the built-in range() function.
  5. Each process sums its own chunk using the sum() function and stores the result in a variable called result; a process left without a chunk contributes 0.
  6. We use comm.Barrier() to ensure that all processes have finished summing their chunks before continuing.
  7. We use comm.gather() to collect the partial sums from all processes into a list called results on the root process (rank 0); on every other rank, gather() returns None (see the sketch after this list).
  8. Finally, the root process sums the results in results to obtain the final result and prints it.
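The gather step is the only point where data actually moves between processes. As a minimal, self-contained sketch of its semantics (the script name gather_demo.py and the launch command mpiexec -n 4 are just examples, not part of the code above), each rank passes one object to gather() and only the root receives them as a list ordered by rank:

# Run with e.g.: mpiexec -n 4 python gather_demo.py
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

# Every rank contributes one Python object to the collective...
gathered = comm.gather(rank * 10, root=0)

# ...but only the root receives the list; every other rank gets None.
if rank == 0:
    print(gathered)  # with 4 processes: [0, 10, 20, 30]
else:
    assert gathered is None

The same split/sum/gather pattern can also be wrapped in a decorator, as the next snippet shows.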
# MPI as a decorator (not elegant at all)
from mpi4py import MPI
import math

comm = MPI.COMM_WORLD

def mpi_parallelize(func):
    def wrapper(*args, **kwargs):
        # The first positional argument is the list to distribute across ranks.
        lst = args[0]
        
        rank = comm.Get_rank()
        size = comm.Get_size()

        # Split the range of indices to process based on the number of processes
        chunk_size = math.ceil(len(lst) / size)
        chunks = [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]

        # Call the original function on this rank's chunk
        # (an empty list if there are more ranks than chunks)
        result = func(chunks[rank] if rank < len(chunks) else [], **kwargs)

        comm.Barrier()
        # Gather the results from all processes into a single list
        results = comm.gather(result, root=0)

        # The root process combines the results and returns the final result;
        # every other rank returns None.
        if rank == 0:
            return sum(results)

    return wrapper

@mpi_parallelize
def sum_list(lst):
    return sum(lst)

if __name__ == "__main__":
    lst = list(range(1, 1001))
    final_sum = sum_list(lst)

    # Only the root process gets a result back; the other ranks receive None.
    if final_sum is not None:
        print(final_sum, sum(lst) == final_sum)

To terminate a running MPI program, you can use the Abort() method of the MPI.COMM_WORLD communicator. This will immediately abort all processes in the communicator and terminate the program.

Here's an example of how to use Abort() to terminate a running MPI program:

from mpi4py import MPI
import time

def main():
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    # Wait for a few seconds before aborting
    if rank == 0:
        print("Starting program, will abort in 5 seconds...")
        time.sleep(5)
        print("Aborting program...")
        comm.Abort()  # immediately terminates every process in COMM_WORLD

    # Do some work
    print("Process {} starting work...".format(rank))
    time.sleep(2)
    print("Process {} finished work.".format(rank))

if __name__ == "__main__":
    main()

In this example, each process gets its rank from MPI.COMM_WORLD. The root process (rank 0) waits for 5 seconds and then calls Abort() to terminate the program, while the other processes continue running until they finish their work or are killed by Abort().

Note that when you call Abort(), all processes in the communicator will immediately terminate, even if they are in the middle of executing code. This can result in unexpected behavior or incomplete work. Therefore, you should use Abort() with caution and ensure that your code is designed to handle unexpected termination gracefully.
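One pattern that fits this advice is to call Abort() from an exception handler: if a single rank raises an exception while the others are blocked in a collective call such as gather(), that rank would exit on its own and leave the rest of the job hanging. The sketch below is only an illustration of that idea; risky_work is a hypothetical placeholder for whatever your ranks actually do:

from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()

def risky_work(rank):
    # Hypothetical work that fails on one rank.
    if rank == 1:
        raise RuntimeError("something went wrong on rank 1")
    return rank * rank

if __name__ == "__main__":
    try:
        value = risky_work(rank)
        results = comm.gather(value, root=0)
        if rank == 0:
            print(results)
    except Exception as exc:
        # Without this, the failing rank would exit while the other ranks
        # block forever inside gather(); Abort() tears the whole job down.
        print("Rank {} failed: {}".format(rank, exc))
        comm.Abort(1)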
