Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions rsopt/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Probe script: executed as a separate process to check that MPI can be
# initialized in this environment; a non-zero exit status signals failure.
from mpi4py import MPI

# Importing mpi4py.MPI may already have initialized MPI (this depends on the
# mpi4py.rc settings in effect). Initializing MPI twice is an error, so only
# call Init() when MPI has not been initialized yet.
if not MPI.Is_initialized():
    MPI.Init()
40 changes: 37 additions & 3 deletions rsopt/mpi.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,52 @@
# Cached result of the first successful get_mpi_environment() call.
active_env = None


def get_mpi_environment():
    """Return the MPI configuration for this process, or None if MPI is not in use.

    The result is a dict with the keys ``'mpi_comm'`` (``MPI.COMM_WORLD``),
    ``'comms'`` (always ``'mpi'``), ``'nworkers'`` (world size minus one) and
    ``'is_manager'`` (True on rank 0). A successfully built environment is
    stored in the module-level ``active_env`` and returned directly on later
    calls.

    Returns None when:
      * mpi4py is not installed,
      * the MPI initialization probe run in a child process fails, or
      * the MPI world contains a single rank.
    """
    global active_env

    # Test for mpi4py install
    try:
        import mpi4py
        # Keep the import below from initializing MPI automatically;
        # initialization is only attempted after the probe has succeeded.
        mpi4py.rc.initialize = False
        from mpi4py import MPI
    except ModuleNotFoundError:
        # mpi4py not installed so it can't be used
        return None

    # If we already ran this process, return the active environment
    if active_env is not None:
        return active_env

    from inspect import currentframe, getframeinfo
    frameinfo = getframeinfo(currentframe())
    print(f"Initializing MPI from {frameinfo.filename}:L{frameinfo.lineno}", flush=True)

    # Test MPI initialization in a separate process so that a failure there
    # cannot take this process down with it.
    import os
    import subprocess
    import sys

    # __main__.py sits next to this module inside the package directory.
    probe_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "__main__.py")
    # Use the interpreter running this process rather than whatever "python"
    # resolves to on PATH, so the probe sees the same mpi4py installation.
    probe = subprocess.run([sys.executable, probe_script], check=False)

    if probe.returncode != 0:
        return None

    # The probe only initialized MPI in the child process. Automatic
    # initialization was disabled above, so do it here unless something else
    # in this process has already initialized MPI.
    # NOTE(review): with rc.initialize = False mpi4py may not finalize MPI
    # automatically at exit - confirm whether an explicit MPI.Finalize() is needed.
    if not MPI.Is_initialized():
        MPI.Init()

    world_size = MPI.COMM_WORLD.Get_size()
    if world_size == 1:
        # MPI not being used
        # (if user did start MPI with size 1 this would be an illegal
        # configuration since: main + 1 worker = 2 ranks)
        return None

    mpi_environment = {
        'mpi_comm': MPI.COMM_WORLD,
        'comms': 'mpi',
        'nworkers': world_size - 1,
        'is_manager': MPI.COMM_WORLD.Get_rank() == 0,
    }

    # Save global environment
    active_env = mpi_environment

    return mpi_environment
20 changes: 8 additions & 12 deletions rsopt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import numpy as np
import pickle
from libensemble.tools import save_libE_output
from .mpi import active_env as MPI_ENV
from .mpi import get_mpi_environment

SLURM_PREFIX = 'nid'

Expand Down Expand Up @@ -54,12 +56,9 @@ def return_nodelist(nodelist_string):

def return_used_nodes():
"""Returns all used processor names to rank 0 or an empty list if MPI not used. For ranks != 0 returns None."""
try:
from mpi4py import MPI
except ModuleNotFoundError:
# If MPI not being used to start rsopt then no nodes will have srun executed yet
return []

if not MPI_ENV:
get_mpi_environment()

rank = MPI.COMM_WORLD.Get_rank()
name = MPI.Get_processor_name()
all_names = MPI.COMM_WORLD.gather(name, root=0)
Expand Down Expand Up @@ -92,11 +91,8 @@ def return_unused_node():

def broadcast(data, root_rank=0):
"""broadcast, or don't bother"""
try:
from mpi4py import MPI
except ModuleNotFoundError:
# If MPI not available for import then assume it isn't needed
return data
if not MPI_ENV:
get_mpi_environment()

if MPI.COMM_WORLD.Get_size() == 1:
return data
Expand All @@ -121,4 +117,4 @@ def _libe_save(H, persis_info, mess, filename):
np.save(filename, H)

with open(filename + ".pickle", "wb") as f:
pickle.dump(persis_info, f)
pickle.dump(persis_info, f)