
Commit 58cc669

BoyuanFeng authored and pytorchmergebot committed
[BE] Type annotate wrapper_benchmark.py and cuda_combined_scheduling.py (pytorch#145542)
Pull Request resolved: pytorch#145542
Approved by: https://github.com/eellison
1 parent 8cc6f17 commit 58cc669

4 files changed: +83 −37 lines changed

torch/_inductor/codegen/cuda_combined_scheduling.py

Lines changed: 23 additions & 10 deletions

@@ -1,7 +1,7 @@
 # mypy: allow-untyped-defs
 from __future__ import annotations

-from typing import Optional, TYPE_CHECKING, Union
+from typing import Any, List, Optional, Tuple, TYPE_CHECKING, Union

 from ..scheduler import (
     BaseSchedulerNode,
@@ -17,12 +17,17 @@

 if TYPE_CHECKING:
     from collections.abc import Sequence
+    from typing_extensions import TypeAlias
+
+    from sympy import Expr

     import torch
     from torch.utils._ordered_set import OrderedSet

     from .common import BackendFeature

+    _IntLike: TypeAlias = Union[int, Expr]
+

 class CUDACombinedScheduling(BaseScheduling):
     """
@@ -67,15 +72,17 @@ def can_fuse_horizontal(
         )  # always False at the moment
         return self._triton_scheduling.can_fuse_horizontal(node1, node2)

-    def group_fn(self, sizes):
+    def group_fn(
+        self, sizes: Sequence[Sequence[_IntLike]]
+    ) -> tuple[tuple[_IntLike, ...], ...]:
         return self._triton_scheduling.group_fn(sizes)

     def codegen_template(
         self,
         template_node: BaseSchedulerNode,
         epilogue_nodes: Sequence[BaseSchedulerNode],
         prologue_nodes: Sequence[BaseSchedulerNode],
-    ):
+    ) -> Optional[str]:
         if self._cuda_cpp_scheduling.is_cuda_cpp_template(template_node):
             assert not epilogue_nodes
             assert not prologue_nodes
@@ -93,28 +100,34 @@ def codegen_template(
             template_node, epilogue_nodes, prologue_nodes
         )

-    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]):
+    def codegen_node(self, node: Union[FusedSchedulerNode, SchedulerNode]) -> None:
         return self._triton_scheduling.codegen_node(node)

-    def codegen_sync(self):
+    def codegen_sync(self) -> None:
         return self._triton_scheduling.codegen_sync()

-    def flush(self):
+    def flush(self) -> None:
         return self._triton_scheduling.flush()

-    def codegen_combo_kernel(self, *args, **kwargs):
+    def codegen_combo_kernel(self, *args: Any, **kwargs: Any) -> None:
         return self._triton_scheduling.codegen_combo_kernel(*args, **kwargs)

-    def benchmark_fused_nodes(self, nodes):
+    def benchmark_fused_nodes(
+        self, nodes: Sequence[BaseSchedulerNode]
+    ) -> Tuple[float, str]:
         return self._triton_scheduling.benchmark_fused_nodes(nodes)

     def benchmark_codegened_module(self, module):
         return self._triton_scheduling.benchmark_codegened_module(module)

-    def generate_kernel_code_from_nodes(self, nodes, benchmark_kernel=False):
+    def generate_kernel_code_from_nodes(
+        self, nodes: Sequence[Any], benchmark_kernel: bool = False
+    ) -> str:
         return self._triton_scheduling.generate_kernel_code_from_nodes(
             nodes, benchmark_kernel
         )

-    def benchmark_combo_kernel(self, node_list):
+    def benchmark_combo_kernel(
+        self, node_list: Sequence[BaseSchedulerNode]
+    ) -> tuple[float, float, List[Optional[str]]]:
         return self._triton_scheduling.benchmark_combo_kernel(node_list)
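Side note on the new `_IntLike` alias above: because the file already uses `from __future__ import annotations`, names imported only under `if TYPE_CHECKING:` (sympy's Expr, TypeAlias, and `_IntLike` itself) can appear in annotations such as `group_fn`'s without creating a runtime dependency. A minimal standalone sketch of that pattern, using hypothetical names rather than the Inductor classes:

from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # These imports exist only for the type checker; they are skipped at runtime.
    from sympy import Expr
    from typing_extensions import TypeAlias

    _IntLike: TypeAlias = Union[int, Expr]


def double_size(size: _IntLike) -> _IntLike:
    # PEP 563 keeps annotations as strings, so _IntLike need not exist at runtime.
    return size * 2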

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 4 deletions

@@ -4137,8 +4137,8 @@ def add_multi_kernel_choices(

     def benchmark_combo_kernel(self, node_list):
         mod: ModuleType
-        ms: int
-        ms_clone: int
+        ms: float
+        ms_clone: float

         def cache_file_path():
             assert mod.__file__ is not None
@@ -4157,7 +4157,7 @@ def store_cache():
                 fd.write(str(ms) + " " + str(ms_clone))

         total_ms, file_list = 0, []
-        total_clone_ms = 0
+        total_clone_ms: float = 0.0
         removed_buffers_orig = V.graph.removed_buffers
         V.graph.removed_buffers = OrderedSet(removed_buffers_orig)
         inplaced_to_remove_orig = V.graph.inplaced_to_remove
@@ -4186,7 +4186,7 @@ def store_cache():
             )
             ms, ms_clone = load_cache()
             if ms is not None:
-                total_ms += ms
+                total_ms += ms  # type: ignore[assignment]
                 total_clone_ms += ms_clone
                 file_list.append(mod.__file__)
                 continue
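A note on why these annotations mostly satisfy mypy rather than change behavior: `total_ms, file_list = 0, []` makes mypy infer `total_ms` as int, while `ms` is now declared float, so `total_ms += ms` assigns a float into an int-typed variable (hence the remaining `# type: ignore[assignment]`); the explicit `total_clone_ms: float = 0.0` avoids the same complaint for the clone-time accumulator. A rough illustration of that inference rule with made-up names (plain mypy behavior, not the Inductor code):

def measure() -> float:
    return 1.5


total_ms = 0                   # mypy infers int from the literal
total_ms += measure()          # mypy: float assigned to variable of type "int"

total_clone_ms: float = 0.0    # explicit annotation keeps the accumulator a float
total_clone_ms += measure()    # accepted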

torch/_inductor/scheduler.py

Lines changed: 2 additions & 2 deletions

@@ -3988,7 +3988,7 @@ def _codegen(self) -> None:

     def benchmark_combo_kernel(
         self, node_list: Sequence[BaseSchedulerNode]
-    ) -> tuple[float, float, str]:
+    ) -> tuple[float, float, List[Optional[str]]]:
         """
         Benchmark fused list of nodes and return the execution time
         in milliseconds on randomly generated inputs.
@@ -4228,7 +4228,7 @@ def get_fusion_pair_priority(

     def benchmark_combo_kernel(
         self, node_list: Sequence[BaseSchedulerNode]
-    ) -> tuple[float, float, str]:
+    ) -> tuple[float, float, List[Optional[str]]]:
         """
         Benchmark the list of nodes to combine and return the execution time
         and memory copy time in milliseconds on randomly generated inputs.
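The corrected signature mirrors the Triton backend's benchmark_combo_kernel above: a total kernel time, a total clone time, and a list of generated-module file paths (per the annotation, individual entries may be None). A small hedged sketch of consuming such a value, with illustrative names only:

from typing import List, Optional


def summarize(result: tuple[float, float, List[Optional[str]]]) -> None:
    total_ms, total_clone_ms, file_list = result
    for path in file_list:
        if path is not None:
            print(f"kernel module: {path} ({total_ms:.3f} ms total)")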

torch/_inductor/wrapper_benchmark.py

Lines changed: 54 additions & 21 deletions

@@ -1,8 +1,9 @@
-# mypy: allow-untyped-defs
 import dataclasses
 import datetime
 import tempfile
 from collections import defaultdict
+from types import ModuleType
+from typing import Any, Dict, Optional, Protocol

 import torch
 from torch.autograd import DeviceType
@@ -12,6 +13,11 @@
 from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes


+class BenchmarkCallableType(Protocol):
+    def __call__(self, times: int, repeat: int) -> float:
+        ...
+
+
 _kernel_category_choices = [
     "foreach",
     "persistent_reduction",
@@ -22,7 +28,7 @@
 ]


-def get_kernel_category_by_source_code(src_code):
+def get_kernel_category_by_source_code(src_code: str) -> str:
     """
     Similar to get_kernel_category but use the source code. Call this API
     if we have not compile the src_code to module yet.
@@ -36,7 +42,7 @@ def get_kernel_category_by_source_code(src_code):
     return "unknown"


-def get_kernel_category(kernel_mod):
+def get_kernel_category(kernel_mod: ModuleType) -> str:
     """
     Given the module defining a triton kernel, return the category of the kernel.
     Category can be one of:
@@ -54,7 +60,7 @@ def get_kernel_category(kernel_mod):
     return "unknown"


-def get_triton_kernel(mod):
+def get_triton_kernel(mod: ModuleType):  # type: ignore[no-untyped-def]
     from torch._inductor.runtime.triton_heuristics import CachingAutotuner

     cand_list = [
@@ -66,7 +72,9 @@ def get_triton_kernel(mod):
     return cand_list[0]


-def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
+def benchmark_all_kernels(
+    benchmark_name: str, benchmark_all_configs: Optional[Dict[Any, Any]]
+) -> None:
     """
     An experimental API used only when config.benchmark_kernel is true.

@@ -98,7 +106,13 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
         if num_gb is None:
             num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9

-        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
+        def get_info_str(
+            ms: float,
+            n_regs: Optional[Any],
+            n_spills: Optional[Any],
+            shared: Optional[Any],
+            prefix: str = "",
+        ) -> str:
             if not any(x is None for x in [n_regs, n_spills, shared]):
                 kernel_detail_str = (
                     f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem"
@@ -156,22 +170,31 @@ class ProfileEvent:


 def parse_profile_event_list(
-    benchmark_name, event_list, wall_time_ms, nruns, device_name
-):
-    def get_self_device_time(ev):
+    benchmark_name: str,
+    event_list: torch.autograd.profiler_util.EventList,
+    wall_time_ms: float,
+    nruns: int,
+    device_name: str,
+) -> None:
+    def get_self_device_time(
+        ev: torch.autograd.profiler_util.EventList,
+    ) -> float:
         """
         ev.self_device_time_total is in microsecond. Convert to millisecond.
         """
-        return ev.self_device_time_total / 1000 / nruns
+        return ev.self_device_time_total / 1000 / nruns  # type: ignore[attr-defined]

-    all_events = defaultdict(list)
+    all_events: Dict[str, list[ProfileEvent]] = defaultdict(list)

-    def add_event(ev, category):
+    def add_event(
+        ev: torch.autograd.profiler_util.EventList,
+        category: str,
+    ) -> None:
         profile_ev = ProfileEvent(
             category=category,
-            key=ev.key,
+            key=ev.key,  # type: ignore[attr-defined]
             self_device_time_ms=get_self_device_time(ev),
-            count=ev.count / nruns,  # average across all runs
+            count=ev.count / nruns,  # type: ignore[operator] # average across all runs
         )
         all_events[category].append(profile_ev)

@@ -194,7 +217,7 @@ def add_event(ev, category):

         add_event(ev, category)

-    def report_category(category, profile_events):
+    def report_category(category: str, profile_events: list[ProfileEvent]) -> float:
         if not device_name:
             return 0.0

@@ -225,7 +248,7 @@ def report_category(category, profile_events):
         )
         return total_time

-    def report():
+    def report() -> None:
         category_list = [
             "triton_pointwise",
             "triton_reduction",
@@ -273,8 +296,12 @@ def report():


 def perf_profile(
-    wall_time_ms, times, repeat, benchmark_name, benchmark_compiled_module_fn
-):
+    wall_time_ms: float,
+    times: int,
+    repeat: int,
+    benchmark_name: str,
+    benchmark_compiled_module_fn: BenchmarkCallableType,
+) -> None:
     with torch.profiler.profile(record_shapes=True) as p:
         benchmark_compiled_module_fn(times=times, repeat=repeat)

@@ -289,7 +316,9 @@ def perf_profile(
     )


-def ncu_analyzer(benchmark_name, benchmark_compiled_module_fn):
+def ncu_analyzer(
+    benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
+) -> None:
     import inspect
     import os
     import subprocess
@@ -339,7 +368,9 @@ def ncu_analyzer(benchmark_name, benchmark_compiled_module_fn):
     return


-def collect_memory_snapshot(benchmark_compiled_module_fn):
+def collect_memory_snapshot(
+    benchmark_compiled_module_fn: BenchmarkCallableType,
+) -> None:
     assert torch.cuda.is_available()

     torch.cuda.memory._record_memory_history(max_entries=100000)
@@ -350,7 +381,9 @@ def collect_memory_snapshot(benchmark_compiled_module_fn):
     print(f"The collect memory snapshot has been written to {snapshot_path}")


-def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
+def compiled_module_main(
+    benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
+) -> None:
     """
     This is the function called in __main__ block of a compiled module.
     """
