1- # mypy: allow-untyped-defs
21import functools
32import itertools
43import operator
54import typing
6- from typing import Callable , Optional , Union
5+ from typing import Any , Callable , List , Optional , Sequence , Tuple , Union
76
87import torch
98import torch ._inductor .runtime .runtime_utils
4443_skip_do_bench_times = False
4544
4645
47- def fetch_fake_tensors (match , kwarg_names ) -> list [Tensor ]:
46+ def fetch_fake_tensors (match : Match , kwarg_names : Sequence [ str ] ) -> list [Tensor ]:
4847 kwargs = match .kwargs
4948 return [kwargs [name ].meta ["val" ] for name in kwarg_names ]
5049
5150
52- def unwrap_fake_args (* arg_names ):
53- def decorator (func ):
54- def wrapper (match ):
51+ def unwrap_fake_args (
52+ * arg_names : str ,
53+ ) -> Callable [[Callable [..., Any ]], Callable [[Match ], Any ]]:
54+ def decorator (func : Callable [..., Any ]) -> Callable [[Match ], Any ]:
55+ def wrapper (match : Match ) -> Any :
5556 fake_tensors = fetch_fake_tensors (match , arg_names )
5657 return func (* fake_tensors )
5758
@@ -116,7 +117,7 @@ def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
116117 )
117118
118119
119- def get_padded_length (x : Union [int , torch .SymInt ], alignment_size ) -> int :
120+ def get_padded_length (x : Union [int , torch .SymInt ], alignment_size : int ) -> int :
120121 # we don't pad x if it is symbolic
121122 if isinstance (x , torch .SymInt ) or alignment_size == 0 or x % alignment_size == 0 :
122123 return 0
@@ -155,11 +156,11 @@ def pad_addmm(
155156 m_padded_length : int ,
156157 k_padded_length : int ,
157158 n_padded_length : int ,
158- beta = 1.0 ,
159- alpha = 1.0 ,
159+ beta : float = 1.0 ,
160+ alpha : float = 1.0 ,
160161 mat1_pre_padded : bool = False ,
161162 mat2_pre_padded : bool = False ,
162- ):
163+ ) -> Tensor :
163164 # for paddings, dim order is reversed for some reasons
164165 # and for every dim, we need to specify left and right padding
165166 if not mat1_pre_padded :
@@ -191,7 +192,11 @@ def pad_addmm(
191192
192193
193194def addmm_replace (
194- input : Optional [Tensor ], mat1 : Tensor , mat2 : Tensor , beta = 1.0 , alpha = 1.0
195+ input : Optional [Tensor ],
196+ mat1 : Tensor ,
197+ mat2 : Tensor ,
198+ beta : float = 1.0 ,
199+ alpha : float = 1.0 ,
195200) -> Tensor :
196201 k_padded_length = get_padded_length (mat1 .shape [1 ], get_alignment_size (mat1 ))
197202 n_padded_length = get_padded_length (mat2 .shape [1 ], get_alignment_size (mat2 ))
@@ -242,42 +247,42 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
242247
243248
244249@functools .lru_cache (None )
245- def get_pad_cache ():
250+ def get_pad_cache () -> torch . _inductor . codecache . LocalCache :
246251 return torch ._inductor .codecache .LocalCache ()
247252
248253
249254def get_cached_should_pad (key : str ) -> bool :
250- return get_pad_cache ().lookup (key )
255+ return get_pad_cache ().lookup (key ) # type: ignore[return-value]
251256
252257
253- def set_cached_should_pad (key : str , value : bool ):
258+ def set_cached_should_pad (key : str , value : bool ) -> None :
254259 return get_pad_cache ().set_value (key , value = value )
255260
256261
257262def get_cached_base_mm_benchmark_time (key : str ) -> float :
258- return get_pad_cache ().lookup (key )
263+ return get_pad_cache ().lookup (key ) # type: ignore[return-value]
259264
260265
261- def set_cached_base_mm_benchmark_time (key : str , value : float ):
266+ def set_cached_base_mm_benchmark_time (key : str , value : float ) -> None :
262267 return get_pad_cache ().set_value (key , value = value )
263268
264269
265270def should_pad_bench_key (
266- match ,
271+ match : Match ,
267272 mat1 : Tensor ,
268273 mat2 : Tensor ,
269- op ,
274+ op : torch . _ops . OpOverloadPacket ,
270275 input : Optional [Tensor ] = None ,
271- is_base_time_key = False ,
276+ is_base_time_key : bool = False ,
272277) -> str :
273- def tensor_key (t ) :
278+ def tensor_key (t : Tensor ) -> Tuple [ torch . Size , Tuple [ int , ...], torch . dtype ] :
274279 return (t .shape , t .stride (), t .dtype )
275280
276281 tf32_key = (
277282 None if mat1 .dtype != torch .float32 else torch .backends .cuda .matmul .allow_tf32
278283 )
279284
280- def fmt_pad (name ) :
285+ def fmt_pad (name : str ) -> Optional [ str ] :
281286 if is_base_time_key :
282287 return None
283288 return f"exclude_pad:{ should_exclude_padding_time (match , name )} "
@@ -298,9 +303,9 @@ def fmt_pad(name):
298303 return key
299304
300305
301- def get_non_view_def (node ) :
306+ def get_non_view_def (node : torch . fx . Node ) -> torch . fx . Node :
302307 if node .op == operator .getitem :
303- return get_non_view_def (node .args [0 ])
308+ return get_non_view_def (node .args [0 ]) # type: ignore[arg-type]
304309
305310 if (
306311 node .op == "call_function"
@@ -312,7 +317,7 @@ def get_non_view_def(node):
312317 return node
313318
314319
315- def should_exclude_padding_time (match , arg_name ) :
320+ def should_exclude_padding_time (match : Match , arg_name : str ) -> bool :
316321 node_def = get_non_view_def (match .kwargs [arg_name ])
317322
318323 # constant padding converts tensors to contiguous so even if the input tensor
@@ -349,7 +354,7 @@ def should_exclude_padding_time(match, arg_name):
349354 return node_def .op != "placeholder"
350355
351356
352- def should_pad (key : str , ori_time , pad_time ) -> bool :
357+ def should_pad (key : str , ori_time : float , pad_time : float ) -> bool :
353358 multiplier = 1.1
354359 # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
355360 # tradeoff between performance improvement from shape padding and overhead from additional memory ops
@@ -364,7 +369,7 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
364369 return should_pad
365370
366371
367- def should_pad_mm_bf16 (dtype , M , N , K ) :
372+ def should_pad_mm_bf16 (dtype : torch . dtype , M : int , N : int , K : int ) -> bool :
368373 # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
369374 large_k_threshold_to_pad = torch ._inductor .config .post_grad_fusion_options [
370375 "pad_aten_mm_pass"
@@ -381,7 +386,7 @@ def should_pad_mm_bf16(dtype, M, N, K):
381386 return False
382387
383388
384- def should_pad_bench (* args , ** kwargs ) :
389+ def should_pad_bench (* args : Any , ** kwargs : Any ) -> bool :
385390 with dynamo_timed (
386391 "pad_mm_benchmark" ,
387392 log_pt2_compile_event = True ,
@@ -390,7 +395,7 @@ def should_pad_bench(*args, **kwargs):
390395 return _should_pad_bench (* args , ** kwargs )
391396
392397
393- def get_do_bench ():
398+ def get_do_bench () -> Callable [[ Callable [[], Any ]], float ] :
394399 with dynamo_timed ("pad_mm_benchmark_get_do_bench" ):
395400 return functools .partial (
396401 torch ._inductor .runtime .benchmarking .benchmarker .benchmark_gpu ,
@@ -399,7 +404,11 @@ def get_do_bench():
399404
400405
401406def _should_pad_bench (
402- match , mat1 : Tensor , mat2 : Tensor , op , input : Optional [Tensor ] = None
407+ match : Match ,
408+ mat1 : Tensor ,
409+ mat2 : Tensor ,
410+ op : torch ._ops .OpOverloadPacket ,
411+ input : Optional [Tensor ] = None ,
403412) -> bool :
404413 do_bench = get_do_bench ()
405414
@@ -426,7 +435,9 @@ def _should_pad_bench(
426435 if m_padded_length == k_padded_length == n_padded_length == 0 :
427436 return False
428437
429- def realize_symbols (ds ):
438+ def realize_symbols (
439+ ds : Union [torch .Size , Tuple [torch .SymInt , ...]]
440+ ) -> List [int ]:
430441 return [d if isinstance (d , int ) else d .node .hint for d in ds ]
431442
432443 if any (
@@ -614,7 +625,7 @@ def get_context(
614625 m_padded_length : int ,
615626 k_padded_length : int ,
616627 n_padded_length : int ,
617- ):
628+ ) -> AHContext :
618629 context = AHContext ()
619630
620631 context .add_feature ("m" , mat1 .shape [0 ])
@@ -649,14 +660,16 @@ def run_autoheuristic(
649660 m_padded_length : int ,
650661 k_padded_length : int ,
651662 n_padded_length : int ,
652- do_bench ,
663+ do_bench : Callable [[ Callable [[], Any ]], float ] ,
653664 mat1_pre_padded : bool ,
654665 mat2_pre_padded : bool ,
655- ori_time ,
666+ ori_time : float ,
656667 ori_time_key : str ,
657668 key : str ,
658669) -> Optional [bool ]:
659- def feedback_fn (choice : str ):
670+ def feedback_fn (
671+ choice : str ,
672+ ) -> Optional [float ]:
660673 if choice == orig_choice :
661674 return do_bench (orig_bench_fn )
662675 elif choice == pad_choice :
@@ -669,7 +682,7 @@ def fallback() -> str:
669682 orig_choice = "orig"
670683 pad_choice = "pad"
671684 choices = [orig_choice , pad_choice ]
672- feedback = LocalFeedback (feedback_fn )
685+ feedback = LocalFeedback (feedback_fn ) # type: ignore[arg-type]
673686 context = get_context (
674687 mat1 ,
675688 mat2 ,
@@ -718,7 +731,9 @@ def should_pad_mm(match: Match) -> bool:
718731 )
719732
720733
721- def pad_mat1 (mat1 , * , m_padded_length , k_padded_length , is_bmm = False ):
734+ def pad_mat1 (
735+ mat1 : Tensor , * , m_padded_length : int , k_padded_length : int , is_bmm : bool = False
736+ ) -> Tensor :
722737 if k_padded_length != 0 or m_padded_length != 0 :
723738 # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
724739 pad_arg = [0 , k_padded_length , 0 , m_padded_length ]
@@ -729,7 +744,9 @@ def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
729744 return mat1
730745
731746
732- def pad_mat2 (mat2 , * , k_padded_length , n_padded_length , is_bmm = False ):
747+ def pad_mat2 (
748+ mat2 : Tensor , * , k_padded_length : int , n_padded_length : int , is_bmm : bool = False
749+ ) -> Tensor :
733750 if k_padded_length != 0 or n_padded_length != 0 :
734751 # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
735752 pad_arg = [0 , n_padded_length , 0 , k_padded_length ]
@@ -834,7 +851,7 @@ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
834851
835852
836853@functools .lru_cache (None )
837- def _pad_mm_init ():
854+ def _pad_mm_init () -> None :
838855 from .joint_graph import patterns
839856
840857 if torch .cuda .is_available ():
0 commit comments