Skip to content

Commit 4d842c8

Browse files
authored
Merge pull request #2656 from ekouts/slurm_cpus_per_task_bugfix
[bugfix] Pass explicitly the `--cpus-per-task` option to `srun` for Slurm >= 22.05
2 parents 93de6f4 + 8d4938f commit 4d842c8

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

reframe/core/launchers/mpi.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,52 @@
33
#
44
# SPDX-License-Identifier: BSD-3-Clause
55

6+
import semver
7+
import re
8+
9+
import reframe.utility.osext as osext
610
from reframe.core.backends import register_launcher
711
from reframe.core.launchers import JobLauncher
12+
from reframe.core.logging import getlogger
813
from reframe.utility import seconds_to_hms
914

1015

1116
@register_launcher('srun')
class SrunLauncher(JobLauncher):
    '''Launcher that runs the job's executable through Slurm's ``srun``.

    For Slurm >= 22.05 the ``--cpus-per-task`` option is passed explicitly
    to ``srun``. The installed Slurm version is probed once, at construction
    time, by running ``srun --version``. If the version cannot be
    determined, the option is passed anyway (whenever the job sets
    ``num_cpus_per_task``) and a warning is logged.
    '''

    def __init__(self):
        self.options = []

        # Default to passing `--cpus-per-task`; this is only switched off
        # when an older Slurm (< 22.05) is positively identified.
        self.use_cpus_per_task = True

        slurm_version = None
        try:
            out = osext.run_command('srun --version')

            # Raw string: `\d` and `\.` are regex escapes, not string
            # escapes. The optional `-wlm` suffix also accepts the
            # `slurm-wlm X.Y.Z` banner.
            match = re.search(
                r'slurm(?:-wlm)? (\d+)\.(\d+)\.(\d+)', out.stdout
            )
            if match:
                # Strings like 22.05.1 are not valid semver versions
                # because of the leading zero in `05`, so convert each
                # component to an integer explicitly.
                slurm_version = semver.VersionInfo(
                    *(int(part) for part in match.groups())
                )
        except Exception:
            # Best effort: any failure to run or parse `srun --version`
            # is treated the same as an unknown version.
            slurm_version = None

        if slurm_version is None:
            getlogger().warning(
                'could not get version of Slurm, --cpus-per-task will be '
                'set according to the num_cpus_per_task attribute'
            )
        elif slurm_version < semver.VersionInfo(22, 5, 0):
            self.use_cpus_per_task = False

    def command(self, job):
        '''Return the ``srun`` command line, as a list of arguments.

        ``--cpus-per-task=<n>`` is appended only when
        ``use_cpus_per_task`` is set and ``job.num_cpus_per_task`` is
        truthy.
        '''
        ret = ['srun']
        if self.use_cpus_per_task and job.num_cpus_per_task:
            ret.append(f'--cpus-per-task={job.num_cpus_per_task}')

        return ret
1552

1653

1754
@register_launcher('ibrun')

unittests/test_launchers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def minimal_job(make_job, launcher):
104104

105105
def test_run_command(job):
106106
launcher_name = type(job.launcher).registered_name
107+
# This is relevant only for the srun launcher, because it may
108+
# run in different platforms with older versions of Slurm
109+
job.launcher.use_cpus_per_task = True
107110
command = job.launcher.run_command(job)
108111
if launcher_name == 'alps':
109112
assert command == 'aprun -n 4 -N 2 -d 2 -j 0 --foo'
@@ -116,7 +119,7 @@ def test_run_command(job):
116119
elif launcher_name == 'mpirun':
117120
assert command == 'mpirun -np 4 --foo'
118121
elif launcher_name == 'srun':
119-
assert command == 'srun --foo'
122+
assert command == 'srun --cpus-per-task=2 --foo'
120123
elif launcher_name == 'srunalloc':
121124
assert command == ('srun '
122125
'--job-name=fake_job '
@@ -147,6 +150,9 @@ def test_run_command(job):
147150

148151
def test_run_command_minimal(minimal_job):
149152
launcher_name = type(minimal_job.launcher).registered_name
153+
# This is relevant only for the srun launcher, because it may
154+
# run in different platforms with older versions of Slurm
155+
minimal_job.launcher.use_cpus_per_task = True
150156
command = minimal_job.launcher.run_command(minimal_job)
151157
if launcher_name == 'alps':
152158
assert command == 'aprun -n 1 --foo'

0 commit comments

Comments
 (0)