Skip to content

Commit 159b7ad

Browse files
aorenstepytorchmergebot
authored andcommitted
Improve async workers to handle forking for async compile (pytorch#142072)
Pull Request resolved: pytorch#142072 Approved by: https://github.com/masnesral
1 parent 678f749 commit 159b7ad

File tree

3 files changed

+129
-35
lines changed

3 files changed

+129
-35
lines changed

torch/_inductor/async_compile.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mypy: allow-untyped-defs
22
from __future__ import annotations
33

4+
import atexit
45
import functools
56
import logging
67
import os
@@ -327,3 +328,10 @@ def wait(self, scope: Dict[str, Any]) -> None:
327328
pass
328329
else:
329330
AsyncCompile.warm_pool()
331+
332+
# On exit give the workers a chance to clean themselves up. Without this the
333+
# resource_tracker can complain about leaked semaphores coming from the
334+
# ProcessPoolExecutor:
335+
# UserWarning: resource_tracker: There appear to be 5 leaked semaphore objects
336+
# to clean up at shutdown
337+
atexit.register(shutdown_compile_workers)

torch/_inductor/compile_worker/__main__.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
# mypy: allow-untyped-defs
22
import argparse
3+
import functools
4+
import importlib
35
import logging
46
import os
57
import sys
8+
from typing import Type, TypeVar
69

710
from torch._inductor.async_compile import pre_fork_setup
8-
from torch._inductor.compile_worker.subproc_pool import SubprocMain
11+
from torch._inductor.compile_worker.subproc_pool import (
12+
SubprocKind,
13+
SubprocMain,
14+
SubprocPickler,
15+
)
916
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
1017
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path
1118

1219

20+
_T = TypeVar("_T")
21+
22+
1323
log = logging.getLogger(__name__)
1424

1525
_set_triton_ptxas_path()
@@ -22,9 +32,26 @@
2232
pass
2333

2434

35+
def _lookup_and_create_type(base: Type[_T], qname: str) -> _T:
36+
"""
37+
Given a base type and qualified name: import & lookup that name, check
38+
that it's of the given type and then instantiate it.
39+
"""
40+
pkg, name = qname.rsplit(".", 1)
41+
mod = importlib.import_module(pkg)
42+
ty = getattr(mod, name)
43+
if not issubclass(ty, base):
44+
raise TypeError(f"Type {ty} is not a subtype of {base}")
45+
return ty()
46+
47+
2548
def main():
2649
try:
2750
parser = argparse.ArgumentParser()
51+
parser.add_argument(
52+
"--pickler", type=functools.partial(_lookup_and_create_type, SubprocPickler)
53+
)
54+
parser.add_argument("--kind", type=SubprocKind)
2855
parser.add_argument("--workers", type=int)
2956
parser.add_argument("--parent", type=int)
3057
parser.add_argument("--read-fd", type=int)
@@ -38,7 +65,8 @@ def main():
3865
pre_fork_setup()
3966

4067
_async_compile_initializer(args.parent)
41-
SubprocMain(args.workers, read_fd, write_fd).main()
68+
69+
SubprocMain(args.pickler, args.kind, args.workers, read_fd, write_fd).main()
4270
except Exception:
4371
log.exception("Uncaught exception in compile_worker subprocess")
4472

torch/_inductor/compile_worker/subproc_pool.py

Lines changed: 91 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99
import sys
1010
import threading
1111
import traceback
12+
import typing
1213
from concurrent.futures import Future, ProcessPoolExecutor
1314
from concurrent.futures.process import BrokenProcessPool
14-
from typing import Any, BinaryIO, Callable, Dict, Tuple, TypeVar
15+
from enum import Enum
16+
from typing import Any, BinaryIO, Callable, Dict, Optional, Tuple, TypeVar
1517
from typing_extensions import Never, ParamSpec
1618

1719
# _thread_safe_fork is needed because the subprocesses in the pool can read
@@ -88,14 +90,39 @@ def __init__(self, details: str) -> None:
8890
super().__init__(f"An exception occurred in a subprocess:\n\n{details}")
8991

9092

93+
class SubprocPickler:
94+
"""
95+
Allows a caller to provide a custom pickler for passing data with the
96+
subprocess.
97+
"""
98+
99+
def dumps(self, obj: object) -> bytes:
100+
return pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
101+
102+
def loads(self, data: bytes) -> object:
103+
return pickle.loads(data)
104+
105+
106+
class SubprocKind(Enum):
107+
FORK = "fork"
108+
SPAWN = "spawn"
109+
110+
91111
class SubprocPool:
92112
"""
93113
Mimic a concurrent.futures.ProcessPoolExecutor, but wrap it in
94114
a subprocess.Popen() to try to avoid issues with forking/spawning
95115
"""
96116

97-
def __init__(self, nprocs: int) -> None:
117+
def __init__(
118+
self,
119+
nprocs: int,
120+
pickler: Optional[SubprocPickler] = None,
121+
kind: SubprocKind = SubprocKind.FORK,
122+
) -> None:
98123
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
124+
self.pickler = pickler or SubprocPickler()
125+
self.kind = kind
99126

100127
subproc_read_fd, write_fd = os.pipe()
101128
read_fd, subproc_write_fd = os.pipe()
@@ -105,6 +132,8 @@ def __init__(self, nprocs: int) -> None:
105132
cmd = [
106133
sys.executable,
107134
entry,
135+
f"--pickler={self.pickler.__class__.__module__}.{self.pickler.__class__.__name__}",
136+
f"--kind={self.kind.value}",
108137
f"--workers={nprocs}",
109138
f"--parent={os.getpid()}",
110139
f"--read-fd={str(subproc_read_fd)}",
@@ -143,7 +172,7 @@ def submit(
143172
) -> Future[_T]:
144173
if args or kwargs:
145174
job_fn = functools.partial(job_fn, *args, **kwargs)
146-
job_data = pickle.dumps(job_fn, pickle.HIGHEST_PROTOCOL)
175+
job_data = self.pickler.dumps(job_fn)
147176
future: Future[_T]
148177
with self.futures_lock:
149178
job_id = next(self.job_id_count)
@@ -156,31 +185,48 @@ def submit(
156185
return future
157186

158187
def _read_thread(self) -> None:
159-
try:
160-
while True:
188+
while True:
189+
data = b""
190+
try:
161191
job_id, data = _recv_msg(self.read_pipe)
162-
if job_id < 0:
163-
if self.running:
164-
log.warning("SubprocPool unclean exit")
165-
self.read_pipe.close()
192+
except Exception as e:
193+
# Something went wrong during the read. There's no way we have a
194+
# valid job_id.
195+
log.exception("failure in subproc_pool._recv_msg")
196+
job_id = -1
197+
198+
if job_id < 0:
199+
# read_pipe returned None or got exception
200+
if self.running:
201+
log.warning("SubprocPool unclean exit")
202+
self.running = False
203+
self.read_pipe.close()
204+
# Cancel all the pending futures.
205+
self.shutdown()
206+
return
207+
208+
try:
209+
result = self.pickler.loads(data)
210+
except Exception as e:
211+
# Something went wrong unpickling. We have a job_id so just
212+
# notify that particular future and continue on.
213+
log.exception("unpickle failure in SubprocPool._read_thread")
214+
result = e
215+
216+
with self.futures_lock:
217+
if not self.running:
166218
return
167-
result = pickle.loads(data)
168-
with self.futures_lock:
169-
if not self.running:
170-
return
171-
if isinstance(result, _SubprocExceptionInfo):
172-
# An exception occurred in the submitted job
173-
self.pending_futures[job_id].set_exception(
174-
SubprocException(result.details)
175-
)
176-
elif isinstance(result, Exception):
177-
# An exception occurred in some of our subprocess machinery.
178-
self.pending_futures[job_id].set_exception(result)
179-
else:
180-
self.pending_futures[job_id].set_result(result)
181-
del self.pending_futures[job_id]
182-
except Exception:
183-
log.exception("failure in SubprocPool._read_thread")
219+
if isinstance(result, _SubprocExceptionInfo):
220+
# An exception occurred in the submitted job
221+
self.pending_futures[job_id].set_exception(
222+
SubprocException(result.details)
223+
)
224+
elif isinstance(result, Exception):
225+
# An exception occurred in some of our subprocess machinery.
226+
self.pending_futures[job_id].set_exception(result)
227+
else:
228+
self.pending_futures[job_id].set_result(result)
229+
del self.pending_futures[job_id]
184230

185231
def shutdown(self) -> None:
186232
try:
@@ -204,7 +250,16 @@ def shutdown(self) -> None:
204250
class SubprocMain:
205251
"""Communicates with a SubprocPool in the parent process, called by __main__.py"""
206252

207-
def __init__(self, nprocs: int, read_pipe: BinaryIO, write_pipe: BinaryIO) -> None:
253+
def __init__(
254+
self,
255+
pickler: SubprocPickler,
256+
kind: SubprocKind,
257+
nprocs: int,
258+
read_pipe: BinaryIO,
259+
write_pipe: BinaryIO,
260+
) -> None:
261+
self.pickler = pickler
262+
self.kind = kind
208263
self.read_pipe = read_pipe
209264
self.write_pipe = write_pipe
210265
self.write_lock = threading.Lock()
@@ -215,7 +270,7 @@ def __init__(self, nprocs: int, read_pipe: BinaryIO, write_pipe: BinaryIO) -> No
215270
def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor:
216271
pool = ProcessPoolExecutor(
217272
nprocs,
218-
mp_context=multiprocessing.get_context("fork"),
273+
mp_context=multiprocessing.get_context(self.kind.value),
219274
initializer=functools.partial(_async_compile_initializer, os.getpid()),
220275
)
221276
multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
@@ -253,7 +308,9 @@ def submit(self, job_id: int, data: bytes) -> None:
253308
self.pool = self._new_pool(self.nprocs, False)
254309

255310
def _submit_inner(self, job_id: int, data: bytes) -> None:
256-
future = self.pool.submit(functools.partial(SubprocMain.do_job, data))
311+
future = self.pool.submit(
312+
functools.partial(SubprocMain.do_job, self.pickler, data)
313+
)
257314

258315
def callback(_: Future[Any]) -> None:
259316
if not self.running:
@@ -262,7 +319,7 @@ def callback(_: Future[Any]) -> None:
262319
result = future.result()
263320
except Exception as e:
264321
log.exception("Error in subprocess")
265-
result = pickle.dumps(e, pickle.HIGHEST_PROTOCOL)
322+
result = self.pickler.dumps(e)
266323
assert isinstance(result, bytes)
267324
with self.write_lock:
268325
if self.running:
@@ -272,14 +329,15 @@ def callback(_: Future[Any]) -> None:
272329
future.add_done_callback(callback)
273330

274331
@staticmethod
275-
def do_job(data: bytes) -> bytes:
332+
def do_job(pickler: SubprocPickler, data: bytes) -> bytes:
276333
# do the pickle/unpickle in the sub-subproc
277-
job = pickle.loads(data)
334+
job = typing.cast(Callable[[], object], pickler.loads(data))
335+
278336
try:
279337
result = job()
280338
except Exception:
281339
result = _SubprocExceptionInfo(traceback.format_exc())
282-
return pickle.dumps(result, pickle.HIGHEST_PROTOCOL)
340+
return pickler.dumps(result)
283341

284342

285343
def _warm_process_pool(pool: ProcessPoolExecutor, n: int) -> None:

0 commit comments

Comments
 (0)