Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions rsopt/mpi.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@

def get_mpi_environment():
try:
from mpi4py import MPI
import mpi4py
mpi4py.rc.initialize = False
except ModuleNotFoundError:
# mpi4py not installed so it can't be used
return

from mpi4py import MPI
MPI.Init()

if not MPI.COMM_WORLD.Get_size() - 1:
# MPI not being used
# (if user did start MPI with size 1 this would be an illegal configuration since: main + 1 worker = 2 ranks)
Expand All @@ -15,4 +19,4 @@ def get_mpi_environment():
is_manager = MPI.COMM_WORLD.Get_rank() == 0
mpi_environment = {'mpi_comm': MPI.COMM_WORLD, 'comms': 'mpi', 'nworkers': nworkers, 'is_manager': is_manager}

return mpi_environment
return mpi_environment
12 changes: 9 additions & 3 deletions rsopt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ def return_nodelist(nodelist_string):
def return_used_nodes():
"""Returns all used processor names to rank 0 or an empty list if MPI not used. For ranks != 0 returns None."""
try:
from mpi4py import MPI
import mpi4py
mpi4py.rc.initialize = False
except ModuleNotFoundError:
# If MPI not being used to start rsopt then no nodes will have srun executed yet
return []

from mpi4py import MPI
MPI.Init()
rank = MPI.COMM_WORLD.Get_rank()
name = MPI.Get_processor_name()
all_names = MPI.COMM_WORLD.gather(name, root=0)
Expand Down Expand Up @@ -93,11 +96,14 @@ def return_unused_node():
def broadcast(data, root_rank=0):
"""broadcast, or don't bother"""
try:
from mpi4py import MPI
import mpi4py
mpi4py.rc.initialize = False
except ModuleNotFoundError:
# If MPI not available for import then assume it isn't needed
return data

from mpi4py import MPI
MPI.Init()
if MPI.COMM_WORLD.Get_size() == 1:
return data

Expand All @@ -121,4 +127,4 @@ def _libe_save(H, persis_info, mess, filename):
np.save(filename, H)

with open(filename + ".pickle", "wb") as f:
pickle.dump(persis_info, f)
pickle.dump(persis_info, f)