Skip to content

Commit ccfe332

Browse files
d4l3k authored and facebook-github-bot committed
torchx/schedulers: add support for slurm (#87)
Summary: Pull Request resolved: #87 SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects that slurm CLI tools are locally installed and job accounting is enabled. Each app def is scheduled using a heterogenous job via sbatch. Each replica of each role has a unique shell script generated with it's resource allocations and args and then sbatch is used to launch all of them together. Logs are written to the default slurm log file. For more info see: * https://slurm.schedmd.com/sbatch.html * https://slurm.schedmd.com/heterogeneous_jobs.html ``` $ torchx run --scheduler slurm utils.echo --msg hello slurm://torchx_user/1234 $ torchx status slurm://torchx_user/1234 $ less slurm-1234.out ``` Pull Request resolved: #78 Test Plan: Setup test slurm cluster with job accounting enabled ``` $ python setup.py bdist_wheel; and scp dist/torchx-0.1.0.dev2-py3-none-any.whl user@host: $ ssh user@host $ python3.8 -m pip install --user ./torchx-0.1.0.dev2-py3-none-any.whl $ torchx run --scheduler slurm tests.echo {"session": "", "scheduler": "slurm", "api": "schedule", "app_id": "55", "runcfg": "{}", "raw_exception": null, "source": "<unknown>"} === RUN RESULT === Launched app: slurm://torchx_ubuntu/55 {"session": "55", "scheduler": "slurm", "api": "status", "app_id": "55", "runcfg": null, "raw_exception": null, "source": "<unknown>"} App status: { "state": 0, "num_restarts": -1, "msg": "<NONE>", "ui_url": null, "roles": [], "structured_error_msg": "<NONE>" } $ torchx status slurm://torchx_ubuntu/55 {"session": "55", "scheduler": "slurm", "api": "status", "app_id": "55", "runcfg": null, "raw_exception": null, "source": "<unknown>"} AppDef: State: SUCCEEDED Num Restarts: -1 Roles: $ cat slurm-55.out hello world ``` Reviewed By: tierex Differential Revision: D29282763 Pulled By: kiukchung fbshipit-source-id: 576fadf3ccea6e1a87b567692280fac4f43401d2
1 parent d88972d commit ccfe332

File tree

5 files changed

+443
-1
lines changed

5 files changed

+443
-1
lines changed

docs/source/schedulers/slurm.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
Slurm
22
=================
3-
<COMING SOON>
43

4+
.. automodule:: torchx.schedulers.slurm_scheduler
5+
.. currentmodule:: torchx.schedulers.slurm_scheduler
6+
7+
.. autoclass:: SlurmScheduler
8+
:members:

torchx/cli/cmd_run.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,12 @@ def add_arguments(self, subparser: argparse.ArgumentParser) -> None:
152152
help="Does not actually submit the app,"
153153
" just prints the scheduler request",
154154
)
155+
subparser.add_argument(
156+
"--wait",
157+
action="store_true",
158+
default=False,
159+
help="Wait for the app to finish before exiting.",
160+
)
155161
subparser.add_argument(
156162
"conf_file",
157163
type=str,
@@ -183,3 +189,7 @@ def run(self, args: argparse.Namespace) -> None:
183189
status = runner.status(app_handle)
184190
print(f"App status: {status}")
185191
print(f"Job URL: {none_throws(status).ui_url}")
192+
193+
if args.wait:
194+
print("Waiting for the app to finish...")
195+
runner.wait(app_handle)

torchx/schedulers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Dict
99

1010
import torchx.schedulers.local_scheduler as local_scheduler
11+
import torchx.schedulers.slurm_scheduler as slurm_scheduler
1112
from torchx.schedulers.api import Scheduler
1213
from torchx.specs.api import SchedulerBackend
1314
from torchx.util.entrypoints import load_group
@@ -24,6 +25,7 @@ def get_schedulers(
2425
default={
2526
"local": local_scheduler.create_scheduler,
2627
"default": local_scheduler.create_scheduler,
28+
"slurm": slurm_scheduler.create_scheduler,
2729
},
2830
ignore_missing=True,
2931
)

torchx/schedulers/slurm_scheduler.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import csv
9+
import os.path
10+
import shlex
11+
import subprocess
12+
import tempfile
13+
from dataclasses import dataclass
14+
from typing import Any, Dict, List, Mapping, Optional
15+
16+
from torchx.schedulers.api import AppDryRunInfo, DescribeAppResponse, Scheduler
17+
from torchx.specs.api import (
18+
NONE,
19+
AppDef,
20+
AppState,
21+
Role,
22+
RunConfig,
23+
SchedulerBackend,
24+
macros,
25+
)
26+
27+
28+
# Translation table from slurm job-state strings (as reported by ``sacct``)
# to torchx ``AppState``s.  Slurm states with no entry here cannot be
# translated by ``SlurmScheduler.describe``.
SLURM_STATES: Mapping[str, AppState] = {
    "BOOT_FAIL": AppState.FAILED,
    "CANCELLED": AppState.CANCELLED,
    "COMPLETED": AppState.SUCCEEDED,
    "DEADLINE": AppState.FAILED,
    "FAILED": AppState.FAILED,
    "NODE_FAIL": AppState.FAILED,
    "OUT_OF_MEMORY": AppState.FAILED,
    "PENDING": AppState.PENDING,
    "PREEMPTED": AppState.FAILED,
    "RUNNING": AppState.RUNNING,
    "REQUEUED": AppState.PENDING,
    "RESIZING": AppState.PENDING,
    "REVOKED": AppState.FAILED,
    "SUSPENDED": AppState.PENDING,
    "TIMEOUT": AppState.FAILED,
}
45+
46+
47+
def _slurm_escape(s: str) -> str:
    """
    Shell-quote ``s`` while substituting every occurrence of
    ``macros.app_id`` with a shell expression that expands
    ``$SLURM_JOB_ID`` from the environment at runtime.
    """
    # Split on the app-id macro first so the macro itself is never quoted;
    # quote each literal piece, then stitch them back together with the
    # runtime job-id expansion in between.
    literal_pieces = s.split(macros.app_id)
    return '"$SLURM_JOB_ID"'.join(map(shlex.quote, literal_pieces))
54+
55+
56+
@dataclass
class SlurmReplicaRequest:
    """
    Holds parameters for a single replica running on slurm and can be
    materialized down to a bash script.
    """

    # Directory the replica is launched from via ``srun --chdir`` (role.image).
    dir: str
    entrypoint: str
    args: List[str]
    # Extra ``#SBATCH`` options: resource requests plus raw scheduler cfg entries.
    opts: Dict[str, str]
    # Environment variables to export into the job.
    env: Dict[str, str]

    @classmethod
    def from_role(cls, role: "Role", cfg: "RunConfig") -> "SlurmReplicaRequest":
        """
        Builds a replica request from ``role``: scheduler cfg entries become
        raw sbatch options and the role's resource limits (cpu/mem/gpu, when
        set and positive) become per-task sbatch resource requests.
        """
        opts = {k: str(v) for k, v in cfg.cfgs.items()}

        if (resource := role.resource) != NONE:
            if (cpu := resource.cpu) > 0:
                opts["cpus-per-task"] = str(cpu)
            if (memMB := resource.memMB) > 0:
                # sbatch interprets a bare number for --mem as megabytes.
                opts["mem"] = str(memMB)
            if (gpu := resource.gpu) > 0:
                opts["gpus-per-task"] = str(gpu)

        return cls(
            dir=role.image,
            entrypoint=role.entrypoint,
            args=list(role.args),
            opts=opts,
            env=dict(role.env),
        )

    def materialize(self) -> str:
        """
        Materializes the request into an sbatch-compatible shell script for
        this single replica.
        """
        sbatch_opts = [f"#SBATCH --{key}={value}" for key, value in self.opts.items()]
        if self.env:
            # BUGFIX: sbatch only honors the *last* ``--export`` option, so
            # emitting one ``#SBATCH --export=K=V`` per variable silently
            # dropped all but one of them (and a bare K=V list also replaces
            # the submission environment).  Emit a single combined directive;
            # ``ALL`` keeps the submission environment (PATH etc.) available.
            joined_env = ",".join(f"{key}={value}" for key, value in self.env.items())
            sbatch_opts.append(f"#SBATCH --export=ALL,{joined_env}")
        sbatch_opts_str = "\n".join(sbatch_opts)

        # Quote args; any app_id macro inside them becomes "$SLURM_JOB_ID".
        escaped_args = [_slurm_escape(arg) for arg in self.args]

        return f"""#!/bin/sh
{sbatch_opts_str}

# exit on error
set -e

srun --chdir={self.dir} {self.entrypoint} {" ".join(escaped_args)}
"""
105+
106+
107+
@dataclass
class SlurmBatchRequest:
    """
    Holds parameters used to launch a slurm job via sbatch.
    """

    # Base sbatch command line; the scheduler appends the generated replica
    # scripts (separated by ":") before executing it.
    cmd: List[str]
    # Maps replica script file name -> the replica request materialized into
    # that script (one entry per replica of every role).
    replicas: Dict[str, SlurmReplicaRequest]
115+
116+
117+
class SlurmScheduler(Scheduler):
    """
    SlurmScheduler is a TorchX scheduling interface to slurm. TorchX expects
    that slurm CLI tools are locally installed and job accounting is enabled.

    Each app def is scheduled using a heterogenous job via sbatch.
    Each replica of each role has a unique shell script generated with its
    resource allocations and args and then sbatch is used to launch all of them
    together.

    Logs are written to the default slurm log file.

    Any scheduler options passed to it are added as SBATCH arguments to each replica.

    For more info see:

    * https://slurm.schedmd.com/sbatch.html
    * https://slurm.schedmd.com/heterogeneous_jobs.html

    .. code-block:: bash

        $ torchx run --scheduler slurm utils.echo --msg hello
        slurm://torchx_user/1234
        $ torchx status slurm://torchx_user/1234
        $ less slurm-1234.out
        ...
    """

    def __init__(self, session_name: str) -> None:
        super().__init__("slurm", session_name)

    def schedule(self, dryrun_info: AppDryRunInfo[SlurmBatchRequest]) -> str:
        """
        Writes each replica request as a script into a temporary directory and
        submits them all as a single heterogeneous job via ``sbatch``.

        Returns the job id printed by ``sbatch --parsable``.
        """
        req = dryrun_info.request
        # BUGFIX: build the final command on a copy so that scheduling the
        # same dryrun_info twice does not keep appending stale script paths
        # onto req.cmd.
        cmd = list(req.cmd)
        with tempfile.TemporaryDirectory() as tmpdir:
            for i, (name, body) in enumerate(req.replicas.items()):
                path = os.path.join(tmpdir, name)
                with open(path, "w") as f:
                    f.write(body.materialize())

                # ":" separates the components of an sbatch heterogeneous job.
                if i > 0:
                    cmd.append(":")
                cmd.append(path)

            # Must run inside the `with` so the scripts still exist when
            # sbatch reads them.
            p = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
        return p.stdout.decode("utf-8").strip()

    def _submit_dryrun(
        self, app: AppDef, cfg: RunConfig
    ) -> AppDryRunInfo[SlurmBatchRequest]:
        """
        Builds (without submitting) the sbatch request: one replica request
        per replica of every role, with the torchx template macros resolved.
        """
        cmd = ["sbatch", "--parsable", "--job-name", app.name]
        replicas = {}
        for i, role in enumerate(app.roles):
            for replica_id in range(role.num_replicas):
                values = macros.Values(
                    img_root=role.image,
                    # app_id stays as the macro itself; it is replaced with a
                    # runtime $SLURM_JOB_ID expansion by _slurm_escape.
                    app_id=macros.app_id,
                    replica_id=str(replica_id),
                )
                name = f"role-{i}-{role.name}-{replica_id}.sh"
                replica_role = values.apply(role)
                replicas[name] = SlurmReplicaRequest.from_role(replica_role, cfg)
        req = SlurmBatchRequest(
            cmd=cmd,
            replicas=replicas,
        )
        return AppDryRunInfo(req, repr)

    def _validate(self, app: AppDef, scheduler: SchedulerBackend) -> None:
        # Skip validation step for slurm
        pass

    def _cancel_existing(self, app_id: str) -> None:
        """Cancels the job via ``scancel``."""
        subprocess.run(["scancel", app_id], check=True)

    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        """
        Queries job accounting via ``sacct`` and translates the slurm state
        into a torchx ``AppState``.

        Returns ``None`` if sacct produced no output for the job.

        Raises:
            ValueError: if the slurm state has no torchx translation.
        """
        p = subprocess.run(
            ["sacct", "--parsable2", "-j", app_id], stdout=subprocess.PIPE, check=True
        )
        output = p.stdout.decode("utf-8").split("\n")
        if len(output) <= 1:
            return None

        # --parsable2 emits "|"-delimited rows with a header line.
        reader = csv.DictReader(output, delimiter="|")

        resp = DescribeAppResponse(
            app_id=app_id,
        )
        for row in reader:
            # Only the row whose JobID matches exactly describes the job
            # itself (sacct may emit additional rows — presumably job steps;
            # verify against sacct output for a real job).
            if row["JobID"] == app_id:
                state = row["State"]
                resp.msg = state
                state_enum = SLURM_STATES.get(state)
                # BUGFIX: was an `assert`, which is stripped under `python -O`
                # and would then silently leave resp.state unset.
                if state_enum is None:
                    raise ValueError(
                        f"failed to translate slurm state {state} to torchx state"
                    )
                resp.state = state_enum

        return resp
215+
216+
217+
def create_scheduler(session_name: str, **kwargs: Any) -> SlurmScheduler:
    """
    Factory entrypoint for the scheduler registry.

    Extra keyword arguments are accepted for interface compatibility and
    ignored.
    """
    return SlurmScheduler(session_name=session_name)

0 commit comments

Comments
 (0)