-# mypy: allow-untyped-defs
 import dataclasses
 import datetime
 import tempfile
 from collections import defaultdict
+from types import ModuleType
+from typing import Any, Dict, Optional, Protocol
 
 import torch
 from torch.autograd import DeviceType
 from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes
 
 
+class BenchmarkCallableType(Protocol):
+    def __call__(self, times: int, repeat: int) -> float:
+        ...
+
+
 _kernel_category_choices = [
     "foreach",
     "persistent_reduction",
 ]
 
 
-def get_kernel_category_by_source_code(src_code):
+def get_kernel_category_by_source_code(src_code: str) -> str:
     """
     Similar to get_kernel_category but use the source code. Call this API
     if we have not compiled the src_code to a module yet.
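
Note on the new `BenchmarkCallableType`: `typing.Protocol` types check structurally, so any callable with a matching signature satisfies the protocol without subclassing it. A minimal sketch (`fake_benchmark` is a hypothetical stand-in, not part of this diff):

```python
from typing import Protocol


class BenchmarkCallableType(Protocol):
    def __call__(self, times: int, repeat: int) -> float:
        ...


def fake_benchmark(times: int, repeat: int) -> float:
    return 0.0  # stand-in for a real timing loop


# Structural match: no inheritance needed for mypy to accept this.
fn: BenchmarkCallableType = fake_benchmark
```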
@@ -36,7 +42,7 @@ def get_kernel_category_by_source_code(src_code):
     return "unknown"
 
 
-def get_kernel_category(kernel_mod):
+def get_kernel_category(kernel_mod: ModuleType) -> str:
     """
     Given the module defining a triton kernel, return the category of the kernel.
     Category can be one of:
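
The hunk elides the body. One plausible shape for such a lookup, assuming the category name lands in the compiled module's namespace (a sketch under that assumption, not necessarily the actual body; `guess_kernel_category` is a hypothetical name):

```python
from types import ModuleType

_kernel_category_choices = ["foreach", "persistent_reduction"]  # truncated above


def guess_kernel_category(kernel_mod: ModuleType) -> str:
    # If exactly one known category name shows up among the module's
    # attributes, take it; otherwise give up.
    choices = [ch for ch in _kernel_category_choices if ch in kernel_mod.__dict__]
    if len(choices) == 1:
        return choices[0]
    return "unknown"
```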
@@ -54,7 +60,7 @@ def get_kernel_category(kernel_mod):
     return "unknown"
 
 
-def get_triton_kernel(mod):
+def get_triton_kernel(mod: ModuleType):  # type: ignore[no-untyped-def]
     from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 
     cand_list = [
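
The comprehension building `cand_list` is hidden by the hunk; what it implies is a scan of module attributes for `CachingAutotuner` instances. A hedged sketch (`find_autotuner` is a hypothetical helper name):

```python
from types import ModuleType


def find_autotuner(mod: ModuleType):
    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

    cand_list = [
        v for v in mod.__dict__.values() if isinstance(v, CachingAutotuner)
    ]
    assert len(cand_list) == 1, "expected exactly one triton kernel in the module"
    return cand_list[0]
```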
@@ -66,7 +72,9 @@ def get_triton_kernel(mod):
     return cand_list[0]
 
 
-def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
+def benchmark_all_kernels(
+    benchmark_name: str, benchmark_all_configs: Optional[Dict[Any, Any]]
+) -> None:
     """
     An experimental API used only when config.benchmark_kernel is true.
 
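
To actually exercise this path, the inductor config flag has to be set before compilation, e.g.:

```python
import torch
import torch._inductor.config as inductor_config

# With benchmark_kernel enabled, inductor emits a benchmark harness into
# each compiled module, which benchmark_all_kernels() can then drive.
inductor_config.benchmark_kernel = True


@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    return (x.sin() + x.cos()).sum(dim=-1)
```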
@@ -98,7 +106,13 @@ def benchmark_all_kernels(benchmark_name, benchmark_all_configs):
         if num_gb is None:
             num_gb = get_num_bytes(*args, num_in_out_args=num_in_out_ptrs) / 1e9
 
-        def get_info_str(ms, n_regs, n_spills, shared, prefix=""):
+        def get_info_str(
+            ms: float,
+            n_regs: Optional[Any],
+            n_spills: Optional[Any],
+            shared: Optional[Any],
+            prefix: str = "",
+        ) -> str:
             if not any(x is None for x in [n_regs, n_spills, shared]):
                 kernel_detail_str = (
                     f" {n_regs:3} regs {n_spills:3} spills {shared:8} shared mem"
@@ -156,22 +170,31 @@ class ProfileEvent:
 
 
 def parse_profile_event_list(
-    benchmark_name, event_list, wall_time_ms, nruns, device_name
-):
-    def get_self_device_time(ev):
+    benchmark_name: str,
+    event_list: torch.autograd.profiler_util.EventList,
+    wall_time_ms: float,
+    nruns: int,
+    device_name: str,
+) -> None:
+    def get_self_device_time(
+        ev: torch.autograd.profiler_util.EventList,
+    ) -> float:
         """
         ev.self_device_time_total is in microseconds. Convert to milliseconds.
         """
-        return ev.self_device_time_total / 1000 / nruns
+        return ev.self_device_time_total / 1000 / nruns  # type: ignore[attr-defined]
 
-    all_events = defaultdict(list)
+    all_events: Dict[str, list[ProfileEvent]] = defaultdict(list)
 
-    def add_event(ev, category):
+    def add_event(
+        ev: torch.autograd.profiler_util.EventList,
+        category: str,
+    ) -> None:
         profile_ev = ProfileEvent(
             category=category,
-            key=ev.key,
+            key=ev.key,  # type: ignore[attr-defined]
             self_device_time_ms=get_self_device_time(ev),
-            count=ev.count / nruns,  # average across all runs
+            count=ev.count / nruns,  # type: ignore[operator]  # average across all runs
         )
         all_events[category].append(profile_ev)
 
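
The typed `all_events` accumulator relies on `defaultdict(list)` creating an empty bucket for each new category key; in brief:

```python
from collections import defaultdict
from typing import Dict, List

events: Dict[str, List[str]] = defaultdict(list)
events["triton_pointwise"].append("kernel_a")  # bucket created on first access
events["triton_pointwise"].append("kernel_b")
assert len(events["triton_pointwise"]) == 2
assert events["triton_unknown"] == []  # empty bucket, also created on access
```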
@@ -194,7 +217,7 @@ def add_event(ev, category):
 
         add_event(ev, category)
 
-    def report_category(category, profile_events):
+    def report_category(category: str, profile_events: list[ProfileEvent]) -> float:
         if not device_name:
             return 0.0
 
@@ -225,7 +248,7 @@ def report_category(category, profile_events):
         )
         return total_time
 
-    def report():
+    def report() -> None:
         category_list = [
             "triton_pointwise",
             "triton_reduction",
@@ -273,8 +296,12 @@ def report():
 
 
 def perf_profile(
-    wall_time_ms, times, repeat, benchmark_name, benchmark_compiled_module_fn
-):
+    wall_time_ms: float,
+    times: int,
+    repeat: int,
+    benchmark_name: str,
+    benchmark_compiled_module_fn: BenchmarkCallableType,
+) -> None:
     with torch.profiler.profile(record_shapes=True) as p:
         benchmark_compiled_module_fn(times=times, repeat=repeat)
 
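
`perf_profile` wraps the benchmark call in `torch.profiler.profile`; a self-contained sketch of the same pattern (the benchmark function here is a hypothetical stand-in for the generated one):

```python
import torch
from torch.profiler import profile


def benchmark_compiled_module_fn(times: int, repeat: int) -> float:
    x = torch.randn(1024, 1024)
    for _ in range(times * repeat):
        x = torch.relu(x)
    return 0.0


with profile(record_shapes=True) as p:
    benchmark_compiled_module_fn(times=10, repeat=10)

# Summarize hotspots and dump a chrome trace for offline inspection.
print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
p.export_chrome_trace("/tmp/compiled_module_trace.json")
```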
@@ -289,7 +316,9 @@ def perf_profile(
     )
 
 
-def ncu_analyzer(benchmark_name, benchmark_compiled_module_fn):
+def ncu_analyzer(
+    benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
+) -> None:
     import inspect
     import os
     import subprocess
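
`ncu_analyzer` shells out to NVIDIA Nsight Compute; the exact flags it passes are not visible in this hunk, but the general shape of such a call is (flags and paths illustrative, requires `ncu` on PATH and a CUDA device):

```python
import subprocess
import sys

cmd = [
    "ncu",
    "--set", "basic",          # lightweight metric set
    "-o", "/tmp/ncu_report",   # writes /tmp/ncu_report.ncu-rep
    sys.executable, "compiled_module.py",
]
subprocess.run(cmd, check=True)
```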
@@ -339,7 +368,9 @@ def ncu_analyzer(benchmark_name, benchmark_compiled_module_fn):
     return
 
 
-def collect_memory_snapshot(benchmark_compiled_module_fn):
+def collect_memory_snapshot(
+    benchmark_compiled_module_fn: BenchmarkCallableType,
+) -> None:
     assert torch.cuda.is_available()
 
     torch.cuda.memory._record_memory_history(max_entries=100000)
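
The snapshot flow started here finishes with `_dump_snapshot`; end to end it looks like this (the workload is a stand-in, and these `torch.cuda.memory` APIs are private, underscore-prefixed):

```python
import tempfile
import torch

assert torch.cuda.is_available()
torch.cuda.memory._record_memory_history(max_entries=100000)

x = torch.randn(1024, 1024, device="cuda")  # stand-in workload
y = x @ x

snapshot_path = f"{tempfile.gettempdir()}/memory_snapshot.pickle"
torch.cuda.memory._dump_snapshot(snapshot_path)  # inspect at pytorch.org/memory_viz
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```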
@@ -350,7 +381,9 @@ def collect_memory_snapshot(benchmark_compiled_module_fn):
     print(f"The collect memory snapshot has been written to {snapshot_path}")
 
 
-def compiled_module_main(benchmark_name, benchmark_compiled_module_fn):
+def compiled_module_main(
+    benchmark_name: str, benchmark_compiled_module_fn: BenchmarkCallableType
+) -> None:
     """
     This is the function called in the __main__ block of a compiled module.
     """
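
For reference, a compiled module's `__main__` block typically invokes this entry point along these lines (assuming the file under diff is torch/_inductor/wrapper_benchmark.py; the generated harness name `benchmark_compiled_module` is assumed here):

```python
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main

    # benchmark_compiled_module is the harness inductor generates when
    # config.benchmark_kernel is set.
    compiled_module_main("my_benchmark", benchmark_compiled_module)
```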