
Commit 88a412e

[torch.compile] fast inductor (#11108)
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: Tyler Michael Smith <[email protected]>
1 parent c301616 commit 88a412e

File tree

3 files changed: +624 -7 lines changed


vllm/compilation/backends.py

Lines changed: 210 additions & 3 deletions
@@ -1,6 +1,10 @@
+import ast
 import copy
 import dataclasses
+import os
+import pprint
 import time
+from collections import defaultdict
 from contextlib import ExitStack
 from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple
 from unittest.mock import patch
@@ -21,6 +25,122 @@
 logger = init_logger(__name__)
 
 
+class InductorHashCache:
+    """
+    Disk format: a Python list of tuples, each tuple is
+    (runtime_shape, graph_index, hash_str).
+    We use a list of tuples for readability.
+
+    In-memory format: a defaultdict of dict, where the key is
+    runtime_shape, and the value is a dict of graph_index to hash_str.
+
+    The data is essentially `Dict[Optional[int], Dict[int, str]]`;
+    we don't use json here because json doesn't support int as a key.
+
+    TODO: better off-the-shelf solution to serialize the data?
+    """
+
+    def __init__(self, cache_dir: str, disabled: bool = False):
+        self.cache: defaultdict = defaultdict(dict)
+        self.disabled = disabled
+        self.cache_dir = cache_dir
+        self.cache_file_path = os.path.join(cache_dir,
+                                            "inductor_hash_cache.py")
+        if disabled:
+            return
+        # set flags so that Inductor and Triton store their cache
+        # in the cache_dir, then users only need to copy the cache_dir
+        # to another machine to reuse the cache.
+        inductor_cache = os.path.join(cache_dir, "inductor_cache")
+        os.makedirs(inductor_cache, exist_ok=True)
+        os.environ["TORCHINDUCTOR_CACHE_DIR"] = inductor_cache
+        triton_cache = os.path.join(cache_dir, "triton_cache")
+        os.makedirs(triton_cache, exist_ok=True)
+        os.environ["TRITON_CACHE_DIR"] = triton_cache
+        if os.path.exists(self.cache_file_path):
+            with open(self.cache_file_path) as f:
+                self.deserialize(f.read())
+
+    def deserialize(self, data: str):
+        # we use ast.literal_eval to parse the data
+        # because it is a safe way to parse Python literals.
+        # do not use eval(), it is unsafe.
+        list_data = ast.literal_eval(data)
+        for runtime_shape, graph_index, hash_str in list_data:
+            self.cache[runtime_shape][graph_index] = hash_str
+
+    def serialize(self) -> str:
+        data = []
+        for runtime_shape, graph_index_to_hash_str in self.cache.items():
+            for graph_index, hash_str in graph_index_to_hash_str.items():
+                data.append((runtime_shape, graph_index, hash_str))
+        printer = pprint.PrettyPrinter(indent=4)
+        return printer.pformat(data)
+
+    def save_to_file(self):
+        if self.disabled:
+            return
+        with open(self.cache_file_path, "w") as f:
+            f.write(self.serialize())
+
+    def __contains__(self, key: Tuple[Optional[int], int]) -> bool:
+        if self.disabled:
+            return False
+        runtime_shape, graph_index = key
+        return runtime_shape in self.cache and graph_index in self.cache[
+            runtime_shape]
+
+    def __getitem__(self, key: Tuple[Optional[int], int]) -> str:
+        if self.disabled:
+            raise KeyError("cannot read from disabled cache")
+        runtime_shape, graph_index = key
+        return self.cache[runtime_shape][graph_index]
+
+    def __setitem__(self, key: Tuple[Optional[int], int], value: str):
+        # setitem for a disabled cache is fine, because we
+        # don't actually write to the disk
+        runtime_shape, graph_index = key
+        self.cache[runtime_shape][graph_index] = value
+
+
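For reference, a minimal sketch (not part of the commit) of what the on-disk format looks like and how pprint plus ast.literal_eval round-trip it; the shape values and hash strings below are made up:

import ast
import pprint
from collections import defaultdict

# hypothetical cache entries: (runtime_shape, graph_index, hash_str)
entries = [
    (None, 0, "fx_hash_general_0"),  # general (dynamic) shape, first graph
    (8, 0, "fx_hash_bs8_0"),         # runtime shape 8, first graph
]

# what serialize() would write to inductor_hash_cache.py
serialized = pprint.PrettyPrinter(indent=4).pformat(entries)

# what deserialize() does on the next run: safe literal parsing, no eval()
cache = defaultdict(dict)
for runtime_shape, graph_index, hash_str in ast.literal_eval(serialized):
    cache[runtime_shape][graph_index] = hash_str

assert cache[None][0] == "fx_hash_general_0"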
+class AlwaysHitShapeEnv:
+    """
+    Why do we need this class:
+
+    For normal `torch.compile` usage, every compilation will have
+    one Dynamo bytecode compilation and one Inductor compilation.
+    The Inductor compilation happens under the context of the
+    Dynamo bytecode compilation, and that context is used to
+    determine the dynamic shape information, etc.
+
+    For our use case, we only run Dynamo bytecode compilation once,
+    and run Inductor compilation multiple times with different shapes
+    plus a general shape. The compilation for specific shapes happens
+    outside of the context of the Dynamo bytecode compilation. At that
+    time, we don't have a shape environment to provide to Inductor, and
+    the Inductor code cache lookup will fail.
+
+    By providing a dummy shape environment that always hits, we can
+    make the Inductor code cache lookup always hit, and we can
+    compile the graph for different shapes as needed.
+
+    The following dummy methods were obtained by trial and error
+    until it worked.
+    """
+
+    def __init__(self) -> None:
+        self.guards: List[Any] = []
+
+    def evaluate_guards_expression(self, *args, **kwargs):
+        return True
+
+    def get_pruned_guards(self, *args, **kwargs):
+        return []
+
+    def produce_guards_expression(self, *args, **kwargs):
+        return ""
+
+
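A minimal sketch of how this dummy shape environment is meant to be used; it simply mirrors the cache-hit branch in the hunk below, assumes the AlwaysHitShapeEnv class above, and uses placeholder hash_str and example_inputs values:

from unittest.mock import patch

from torch._inductor.codecache import FxGraphCache

hash_str = "fx_hash_general_0"  # placeholder: a hash recorded on a previous run
example_inputs = []             # placeholder: the real example inputs of the graph

# every guard check during the lookup now trivially passes, so the cached
# entry keyed by hash_str is returned if it exists in the Inductor cache dir
with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
           lambda *args, **kwargs: AlwaysHitShapeEnv()):
    inductor_compiled_graph = FxGraphCache._lookup_graph(
        hash_str, example_inputs, True, False)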
 def wrap_inductor(graph,
                   example_inputs,
                   additional_inductor_config,
@@ -55,9 +175,93 @@ def wrap_inductor(graph,
     # inductor can inplace modify the graph, so we need to copy it
     # see https://github.com/pytorch/pytorch/issues/138980
     graph = copy.deepcopy(graph)
-    compiled_graph = compile_fx(graph,
-                                example_inputs,
-                                config_patches=current_config)
+
+    cache_data = compilation_config.inductor_hash_cache
+    if (runtime_shape, graph_index) in cache_data:
+        # we compiled this graph before
+        # so we can directly lookup the compiled graph via hash
+        hash_str = cache_data[(runtime_shape, graph_index)]
+        if graph_index == 0:
+            # adds some info logging for the first graph
+            logger.info(
+                "Directly lookup the graph for shape %s from the cache",
+                str(runtime_shape))  # noqa
+        logger.debug(
+            "directly lookup the %s-th graph for shape %s via hash %s",
+            graph_index, str(runtime_shape), hash_str)
+        from torch._inductor.codecache import FxGraphCache
+        with patch("torch._inductor.codecache.FxGraphCache._get_shape_env",
+                   lambda *args, **kwargs: AlwaysHitShapeEnv()):
+            inductor_compiled_graph = FxGraphCache._lookup_graph(
+                hash_str, example_inputs, True, False)
+            assert inductor_compiled_graph is not None, (
+                "Inductor cache lookup failed. Please remove "
+                f"the cache file {compilation_config.inductor_hash_cache.cache_file_path} and try again."  # noqa
+            )
+
+        # Inductor calling convention (function signature):
+        # f(list) -> tuple
+        # Dynamo calling convention (function signature):
+        # f(*args) -> Any
+
+        # need to know if the graph returns a tuple
+        from torch._inductor.compile_fx import graph_returns_tuple
+        returns_tuple = graph_returns_tuple(graph)
+
+        # this is the graph we return to Dynamo to run
+        def compiled_graph(*args):
+            # convert args to list
+            list_args = list(args)
+            graph_output = inductor_compiled_graph(list_args)
+            # unpack the tuple if needed
+            if returns_tuple:
+                return graph_output
+            else:
+                return graph_output[0]
+    else:
+        # it's the first time we compile this graph
+        # the assumption is that we don't have nested Inductor compilation.
+        # compiled_fx_graph_hash will only be called once, and we can hook
+        # it to get the hash of the compiled graph directly.
+        from torch._inductor.codecache import compiled_fx_graph_hash
+
+        def hijack_compiled_fx_graph_hash(*args, **kwargs):
+            out = compiled_fx_graph_hash(*args, **kwargs)
+            # store the hash in the cache
+            nonlocal cache_data
+            cache_data[(runtime_shape, graph_index)] = out[0]
+            if graph_index == 0:
+                # adds some info logging for the first graph
+                logger.info("Cache the graph of shape %s for later use",
+                            str(runtime_shape))
+            logger.debug("store the %s-th graph for shape %s via hash %s",
+                         graph_index, str(runtime_shape), out[0])
+            return out
+
+        def _check_can_cache(*args, **kwargs):
+            # no error means it can be cached.
+            # Inductor refuses to cache the graph outside of Dynamo
+            # tracing context, and also disables caching for graphs
+            # with high-order ops.
+            # For vLLM, in either case, we want to cache the graph.
+            # see https://github.com/pytorch/pytorch/blob/9f5ebf3fc609105a74eab4ccc24932d6353ff566/torch/_inductor/codecache.py#L1221  # noqa
+            return
+
+        def _get_shape_env():
+            return AlwaysHitShapeEnv()
+
+        with patch(  # for hijacking the hash of the compiled graph
+                "torch._inductor.codecache.compiled_fx_graph_hash",
+                hijack_compiled_fx_graph_hash), \
+            patch(  # for providing a dummy shape environment
+                "torch._inductor.codecache.FxGraphCache._get_shape_env",
+                _get_shape_env), \
+            patch(  # for forcing the graph to be cached
+                "torch._inductor.codecache.FxGraphCache._check_can_cache",
+                _check_can_cache):
+            compiled_graph = compile_fx(graph,
+                                        example_inputs,
+                                        config_patches=current_config)
 
     # after compiling the last graph, record the end time
     if graph_index == num_graphs - 1:
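To make the calling-convention comment in the hunk above concrete, here is a tiny self-contained sketch; it is not vLLM code, and fake_inductor_graph / make_dynamo_callable are hypothetical names standing in for the Inductor-compiled callable and the wrapper returned to Dynamo:

from typing import Any, List, Tuple


def fake_inductor_graph(list_args: List[Any]) -> Tuple[Any, ...]:
    # Inductor convention: a single list of inputs in, a tuple out
    return (list_args[0] + list_args[1],)


def make_dynamo_callable(inductor_graph, returns_tuple: bool):
    # Dynamo convention: positional args in, a plain value (or tuple) out
    def compiled_graph(*args) -> Any:
        graph_output = inductor_graph(list(args))
        return graph_output if returns_tuple else graph_output[0]

    return compiled_graph


assert make_dynamo_callable(fake_inductor_graph, returns_tuple=False)(1, 2) == 3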
@@ -457,6 +661,9 @@ def __call__(self, *args) -> Any:
 
             # finished compilations for all required shapes
             if self.is_last_graph and not self.to_be_compiled_sizes:
+
+                # save the hash of the inductor graph for the next run
+                self.compilation_config.inductor_hash_cache.save_to_file()
                 end_monitoring_torch_compile(self.vllm_config)
 
         if not entry.use_cudagraph:
0 commit comments
