Skip to content

Commit a4b0023

Browse files
henrylhtsang and pytorchmergebot
authored and committed
[cutlass backend] Cache config generation locally and remotely (pytorch#154686)
Summary: Trying to cache the JSON list of configs. There is probably some more work to do: * preset * filelock (?) * for cases where we generate from scratch, save it to local as well (?) Test Plan: tested offline Reviewed By: coconutruben Differential Revision: D75334439 Pull Request resolved: pytorch#154686 Approved by: https://github.com/coconutruben, https://github.com/ColinPeppler
1 parent ba51f48 commit a4b0023

File tree

3 files changed

+104
-2
lines changed

3 files changed

+104
-2
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# mypy: allow-untyped-defs
2+
import functools
3+
import hashlib
4+
import json
5+
import logging
6+
import os
7+
import time
8+
from typing import Any, Optional
9+
10+
import torch._inductor.config as config
11+
from torch._inductor.codecache import cutlass_key
12+
from torch._inductor.codegen.cuda.cuda_env import get_cuda_arch, get_cuda_version
13+
from torch._inductor.codegen.cuda.serialization import get_cutlass_operation_serializer
14+
from torch._inductor.runtime.cache_dir_utils import cache_dir
15+
from torch._inductor.utils import clear_on_fresh_inductor_cache
16+
17+
18+
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
19+
20+
21+
# Prefix of the cached config filenames built by _generate_config_filename,
# i.e. "<CONFIG_PREFIX>_<request_key>.json".
CONFIG_PREFIX: str = "configs"
22+
23+
24+
def get_config_request_key(
    arch: str,
    cuda_version: str,
    instantiation_level: str,
) -> str:
    """
    Return a key for the full ops, based on cutlass key, arch, cuda version, and instantiation level.

    Args:
        arch: CUDA architecture string.
        cuda_version: CUDA version string.
        instantiation_level: CUTLASS instantiation level.

    Returns:
        The first 8 hex characters of the SHA-256 of the "-"-joined components.
    """
    raw_cutlass_key = cutlass_key()
    try:
        cutlass_key_str = raw_cutlass_key.decode()
    except UnicodeDecodeError:
        # The key is bytes and is not guaranteed to be valid UTF-8 (e.g. if it
        # is a raw hash digest). Fall back to a hex rendering instead of
        # crashing; keys that decode cleanly are unaffected.
        cutlass_key_str = raw_cutlass_key.hex()

    hash_target = "-".join(
        [
            cutlass_key_str,
            arch,
            cuda_version,
            instantiation_level,
        ]
    )
    return hashlib.sha256(hash_target.encode("utf-8")).hexdigest()[0:8]
41+
42+
43+
def _generate_config_filename(request_key: str) -> str:
    """
    Build the JSON filename under which the full ops for ``request_key`` are stored.
    """
    return "{}_{}.json".format(CONFIG_PREFIX, request_key)
48+
49+
50+
@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def maybe_fetch_ops() -> Optional[list[Any]]:
    """
    Fetch ops from databases.

    Looks for a serialized list of ops in a JSON file under ``cache_dir()``
    first; if no such file exists and ``config.is_fbcode()`` is true, asks the
    remote cache instead. The result is memoized with ``functools.lru_cache``.

    Returns:
        The deserialized ops, or ``None`` when caches are force-disabled, the
        serializer is unavailable, nothing was found, or the local cache file
        could not be read. ``None`` tells the caller to generate the ops
        from scratch.
    """
    if config.force_disable_caches:
        return None

    # The serializer getter is Optional; without it a cached payload cannot be
    # turned back into ops, so treat that as a cache miss instead of failing
    # with an AttributeError on None further down.
    serializer = get_cutlass_operation_serializer()
    if serializer is None:
        return None

    # setup
    arch: str = get_cuda_arch()
    # get_cuda_version might return "12.4.0" or "12.4"
    # but we want to use "12.4"
    version: str = ".".join(get_cuda_version().split(".")[:2])
    instantiation_level: str = config.cuda.cutlass_instantiation_level

    # filename and filepath
    request_key: str = get_config_request_key(arch, version, instantiation_level)
    filename: str = _generate_config_filename(request_key)
    filepath: str = os.path.join(cache_dir(), filename)

    # try fetch
    serialized_ops: Optional[list[str]] = None
    start_time = time.time()
    if os.path.isfile(filepath):
        # locally
        try:
            with open(filepath, encoding="utf-8") as f:
                serialized_ops = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            # An unreadable or truncated cache file must not break compilation;
            # fall back to generating the ops from scratch.
            log.warning(
                "Failed to load CUTLASS configs from local cache %s: %s",
                filename,
                e,
            )
            serialized_ops = None
        else:
            if not isinstance(serialized_ops, list):
                log.warning(
                    "Ignoring local CUTLASS config cache %s: expected a list, got %s",
                    filename,
                    type(serialized_ops).__name__,
                )
                serialized_ops = None
    elif config.is_fbcode():
        from torch._inductor.fb.cutlass_remote_cache import (
            maybe_fetch_cutlass_configs_from_remote,
        )

        # from remote
        serialized_ops = maybe_fetch_cutlass_configs_from_remote(filepath)

    if serialized_ops is None:
        return None

    # deserialize
    full_ops = [serializer.deserialize(x) for x in serialized_ops]
    log.info("Loaded ops from %s cache in %.3fs", filename, time.time() - start_time)
    return full_ops

torch/_inductor/codegen/cuda/gemm_template.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212
import torch.utils._pytree as pytree
13+
from torch._inductor.codegen.cuda.cutlass_cache import maybe_fetch_ops
1314
from torch._inductor.scheduler import BaseSchedulerNode
1415
from torch._inductor.select_algorithm import create_inputs_key
1516
from torch._inductor.utils import clear_on_fresh_inductor_cache
@@ -930,8 +931,14 @@ def gen_ops(self) -> "list[tuple[str, cutlass_gemm_op.GemmOperation]]": # type:
930931
log.debug("Using cached ops for %s", self.cache_key)
931932
return self.filtered_ops_cache[self.cache_key]
932933

933-
full_ops = cutlass_utils.gen_ops()
934-
ops = pytree.tree_flatten(full_ops)[0]
934+
maybe_ops = maybe_fetch_ops()
935+
if maybe_ops is None:
936+
log.debug("Cannot fetch ops from cache, generating ops from scratch")
937+
full_ops = cutlass_utils.gen_ops()
938+
ops = pytree.tree_flatten(full_ops)[0]
939+
else:
940+
log.debug("Using cached ops from cache")
941+
ops = maybe_ops
935942

936943
res: dict[str, cutlass_gemm_op.GemmOperation] = {}
937944
start_time = time.time()

torch/_inductor/codegen/cuda/serialization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# mypy: allow-untyped-defs
22
import enum
3+
import functools
34
import json
45
from enum import Enum
56
from typing import Optional
@@ -458,6 +459,7 @@ def _json_to_enum(cls, json_dict, enum_class):
458459
return enum_class[json_dict["name"]]
459460

460461

462+
@functools.lru_cache(1)
461463
def get_cutlass_operation_serializer() -> Optional[CUTLASSOperationSerializer]:
462464
if not try_import_cutlass():
463465
return None

0 commit comments

Comments
 (0)