
Commit 629fb15

BoyuanFeng authored and pytorchmergebot committed
[BE] Type annotate pad_mm.py (pytorch#145409)
Pull Request resolved: pytorch#145409 Approved by: https://github.com/Skylion007
1 parent 015c6d6 commit 629fb15

File tree

1 file changed (+55, -38 lines)

torch/_inductor/fx_passes/pad_mm.py

Lines changed: 55 additions & 38 deletions
@@ -1,9 +1,8 @@
-# mypy: allow-untyped-defs
 import functools
 import itertools
 import operator
 import typing
-from typing import Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch._inductor.runtime.runtime_utils
@@ -44,14 +43,16 @@
 _skip_do_bench_times = False
 
 
-def fetch_fake_tensors(match, kwarg_names) -> list[Tensor]:
+def fetch_fake_tensors(match: Match, kwarg_names: Sequence[str]) -> list[Tensor]:
     kwargs = match.kwargs
     return [kwargs[name].meta["val"] for name in kwarg_names]
 
 
-def unwrap_fake_args(*arg_names):
-    def decorator(func):
-        def wrapper(match):
+def unwrap_fake_args(
+    *arg_names: str,
+) -> Callable[[Callable[..., Any]], Callable[[Match], Any]]:
+    def decorator(func: Callable[..., Any]) -> Callable[[Match], Any]:
+        def wrapper(match: Match) -> Any:
             fake_tensors = fetch_fake_tensors(match, arg_names)
             return func(*fake_tensors)
 
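As a reading aid for the hunk above, here is a minimal, self-contained sketch of the decorator-typing pattern that the new annotation on unwrap_fake_args expresses. The dict stand-in for Match and the fetch_values / unwrap_args / describe names below are illustrative only and are not part of pad_mm.py.

# Illustrative sketch only: a dict stands in for the pattern-matcher Match,
# and fetch_values stands in for fetch_fake_tensors.
from typing import Any, Callable, Sequence

Match = dict  # hypothetical stand-in type for this sketch


def fetch_values(match: Match, names: Sequence[str]) -> list[Any]:
    return [match[name] for name in names]


def unwrap_args(
    *arg_names: str,
) -> Callable[[Callable[..., Any]], Callable[[Match], Any]]:
    # The decorator factory takes a function over unwrapped values and returns
    # a function that accepts only the match object, mirroring unwrap_fake_args.
    def decorator(func: Callable[..., Any]) -> Callable[[Match], Any]:
        def wrapper(match: Match) -> Any:
            return func(*fetch_values(match, arg_names))

        return wrapper

    return decorator


@unwrap_args("mat1", "mat2")
def describe(mat1: Any, mat2: Any) -> str:
    return f"got {mat1} and {mat2}"


print(describe({"mat1": "A", "mat2": "B"}))  # got A and B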
@@ -116,7 +117,7 @@ def valid_shape_and_stride(t: Optional[Tensor]) -> bool:
     )
 
 
-def get_padded_length(x: Union[int, torch.SymInt], alignment_size) -> int:
+def get_padded_length(x: Union[int, torch.SymInt], alignment_size: int) -> int:
     # we don't pad x if it is symbolic
     if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
         return 0
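For context on what the newly annotated get_padded_length computes: beyond the early return shown above, the function is expected to round a static dimension up to the next multiple of alignment_size and return the amount of padding needed. The helper below is a hedged re-implementation under that assumption, not the verbatim function body from pad_mm.py.

# Hedged sketch: assumes round-up-to-next-multiple semantics; the
# zero-alignment / already-aligned early return mirrors the hunk above.
def padded_length_sketch(x: int, alignment_size: int) -> int:
    if alignment_size == 0 or x % alignment_size == 0:
        return 0
    return (x + alignment_size - 1) // alignment_size * alignment_size - x


print(padded_length_sketch(13, 8))  # 3 -> pad 13 up to 16
print(padded_length_sketch(16, 8))  # 0 -> already aligned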
@@ -155,11 +156,11 @@ def pad_addmm(
     m_padded_length: int,
     k_padded_length: int,
     n_padded_length: int,
-    beta=1.0,
-    alpha=1.0,
+    beta: float = 1.0,
+    alpha: float = 1.0,
     mat1_pre_padded: bool = False,
     mat2_pre_padded: bool = False,
-):
+) -> Tensor:
     # for paddings, dim order is reversed for some reasons
     # and for every dim, we need to specify left and right padding
     if not mat1_pre_padded:
@@ -191,7 +192,11 @@ def pad_addmm(
 
 
 def addmm_replace(
-    input: Optional[Tensor], mat1: Tensor, mat2: Tensor, beta=1.0, alpha=1.0
+    input: Optional[Tensor],
+    mat1: Tensor,
+    mat2: Tensor,
+    beta: float = 1.0,
+    alpha: float = 1.0,
 ) -> Tensor:
     k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
     n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
@@ -242,42 +247,42 @@ def is_mm_compute_bound(M: int, K: int, N: int, dtype: torch.dtype) -> bool:
 
 
 @functools.lru_cache(None)
-def get_pad_cache():
+def get_pad_cache() -> torch._inductor.codecache.LocalCache:
     return torch._inductor.codecache.LocalCache()
 
 
 def get_cached_should_pad(key: str) -> bool:
-    return get_pad_cache().lookup(key)
+    return get_pad_cache().lookup(key)  # type: ignore[return-value]
 
 
-def set_cached_should_pad(key: str, value: bool):
+def set_cached_should_pad(key: str, value: bool) -> None:
     return get_pad_cache().set_value(key, value=value)
 
 
 def get_cached_base_mm_benchmark_time(key: str) -> float:
-    return get_pad_cache().lookup(key)
+    return get_pad_cache().lookup(key)  # type: ignore[return-value]
 
 
-def set_cached_base_mm_benchmark_time(key: str, value: float):
+def set_cached_base_mm_benchmark_time(key: str, value: float) -> None:
     return get_pad_cache().set_value(key, value=value)
 
 
 def should_pad_bench_key(
-    match,
+    match: Match,
     mat1: Tensor,
     mat2: Tensor,
-    op,
+    op: torch._ops.OpOverloadPacket,
     input: Optional[Tensor] = None,
-    is_base_time_key=False,
+    is_base_time_key: bool = False,
 ) -> str:
-    def tensor_key(t):
+    def tensor_key(t: Tensor) -> Tuple[torch.Size, Tuple[int, ...], torch.dtype]:
         return (t.shape, t.stride(), t.dtype)
 
     tf32_key = (
         None if mat1.dtype != torch.float32 else torch.backends.cuda.matmul.allow_tf32
     )
 
-    def fmt_pad(name):
+    def fmt_pad(name: str) -> Optional[str]:
         if is_base_time_key:
             return None
         return f"exclude_pad:{should_exclude_padding_time(match, name)}"
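The two "# type: ignore[return-value]" comments added above are needed because LocalCache.lookup is not typed to return the concrete type that each wrapper promises. The TinyCache class below is a hypothetical stand-in (not the real LocalCache API) that reproduces the same mypy situation.

# Hypothetical stand-in to show why the ignores are needed.
from typing import Any, Dict, Optional


class TinyCache:
    def __init__(self) -> None:
        self._data: Dict[str, Any] = {}

    def lookup(self, key: str) -> Optional[Any]:
        return self._data.get(key)

    def set_value(self, key: str, *, value: Any) -> None:
        self._data[key] = value


_cache = TinyCache()


def get_cached_flag(key: str) -> bool:
    # Returning Optional[Any] where bool is declared is what mypy flags,
    # hence the ignore comments in the real code above.
    return _cache.lookup(key)  # type: ignore[return-value]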
@@ -298,9 +303,9 @@ def fmt_pad(name):
     return key
 
 
-def get_non_view_def(node):
+def get_non_view_def(node: torch.fx.Node) -> torch.fx.Node:
     if node.op == operator.getitem:
-        return get_non_view_def(node.args[0])
+        return get_non_view_def(node.args[0])  # type: ignore[arg-type]
 
     if (
         node.op == "call_function"
@@ -312,7 +317,7 @@ def get_non_view_def(node):
     return node
 
 
-def should_exclude_padding_time(match, arg_name):
+def should_exclude_padding_time(match: Match, arg_name: str) -> bool:
     node_def = get_non_view_def(match.kwargs[arg_name])
 
     # constant padding converts tensors to contiguous so even if the input tensor
@@ -349,7 +354,7 @@ def should_exclude_padding_time(match, arg_name):
     return node_def.op != "placeholder"
 
 
-def should_pad(key: str, ori_time, pad_time) -> bool:
+def should_pad(key: str, ori_time: float, pad_time: float) -> bool:
     multiplier = 1.1
     # Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
     # tradeoff between performance improvement from shape padding and overhead from additional memory ops
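The 1.1x multiplier above encodes the decision rule described in the comment. The sketch below is an assumption drawn from that comment, not the verbatim rest of the function body: padding is chosen only when the original benchmark time exceeds the padded time by more than the 10% overhead margin.

# Hedged sketch of the decision rule described by the comment above.
def should_pad_sketch(ori_time: float, pad_time: float, multiplier: float = 1.1) -> bool:
    return ori_time > pad_time * multiplier


print(should_pad_sketch(1.00, 0.85))  # True: padding wins by more than 10%
print(should_pad_sketch(1.00, 0.95))  # False: gain is within the 1.1x margin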
@@ -364,7 +369,7 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad
 
 
-def should_pad_mm_bf16(dtype, M, N, K):
+def should_pad_mm_bf16(dtype: torch.dtype, M: int, N: int, K: int) -> bool:
     # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
     large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
         "pad_aten_mm_pass"
@@ -381,7 +386,7 @@ def should_pad_mm_bf16(dtype, M, N, K):
     return False
 
 
-def should_pad_bench(*args, **kwargs):
+def should_pad_bench(*args: Any, **kwargs: Any) -> bool:
     with dynamo_timed(
         "pad_mm_benchmark",
         log_pt2_compile_event=True,
@@ -390,7 +395,7 @@ def should_pad_bench(*args, **kwargs):
         return _should_pad_bench(*args, **kwargs)
 
 
-def get_do_bench():
+def get_do_bench() -> Callable[[Callable[[], Any]], float]:
     with dynamo_timed("pad_mm_benchmark_get_do_bench"):
         return functools.partial(
             torch._inductor.runtime.benchmarking.benchmarker.benchmark_gpu,
@@ -399,7 +404,11 @@ def get_do_bench():
 
 
 def _should_pad_bench(
-    match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
+    match: Match,
+    mat1: Tensor,
+    mat2: Tensor,
+    op: torch._ops.OpOverloadPacket,
+    input: Optional[Tensor] = None,
 ) -> bool:
     do_bench = get_do_bench()
 
@@ -426,7 +435,9 @@ def _should_pad_bench(
     if m_padded_length == k_padded_length == n_padded_length == 0:
         return False
 
-    def realize_symbols(ds):
+    def realize_symbols(
+        ds: Union[torch.Size, Tuple[torch.SymInt, ...]]
+    ) -> List[int]:
         return [d if isinstance(d, int) else d.node.hint for d in ds]
 
     if any(
@@ -614,7 +625,7 @@ def get_context(
     m_padded_length: int,
     k_padded_length: int,
     n_padded_length: int,
-):
+) -> AHContext:
     context = AHContext()
 
     context.add_feature("m", mat1.shape[0])
@@ -649,14 +660,16 @@ def run_autoheuristic(
     m_padded_length: int,
     k_padded_length: int,
     n_padded_length: int,
-    do_bench,
+    do_bench: Callable[[Callable[[], Any]], float],
     mat1_pre_padded: bool,
     mat2_pre_padded: bool,
-    ori_time,
+    ori_time: float,
     ori_time_key: str,
     key: str,
 ) -> Optional[bool]:
-    def feedback_fn(choice: str):
+    def feedback_fn(
+        choice: str,
+    ) -> Optional[float]:
         if choice == orig_choice:
             return do_bench(orig_bench_fn)
         elif choice == pad_choice:
@@ -669,7 +682,7 @@ def fallback() -> str:
     orig_choice = "orig"
     pad_choice = "pad"
     choices = [orig_choice, pad_choice]
-    feedback = LocalFeedback(feedback_fn)
+    feedback = LocalFeedback(feedback_fn)  # type: ignore[arg-type]
     context = get_context(
         mat1,
         mat2,
@@ -718,7 +731,9 @@ def should_pad_mm(match: Match) -> bool:
     )
 
 
-def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
+def pad_mat1(
+    mat1: Tensor, *, m_padded_length: int, k_padded_length: int, is_bmm: bool = False
+) -> Tensor:
     if k_padded_length != 0 or m_padded_length != 0:
         # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
         pad_arg = [0, k_padded_length, 0, m_padded_length]
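The reversed dim order mentioned in the comment above follows the constant_pad_nd pad-list convention: the list starts with the last dimension and gives a (left, right) pair per dimension. A small standalone illustration, using torch.nn.functional.pad, which shares the same convention:

# Same pad-list convention as aten.constant_pad_nd: last dim first, (left, right) pairs.
import torch
import torch.nn.functional as F

mat1 = torch.randn(5, 13)  # (m, k)
m_padded_length, k_padded_length = 3, 3
padded = F.pad(mat1, [0, k_padded_length, 0, m_padded_length])
print(padded.shape)  # torch.Size([8, 16])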
@@ -729,7 +744,9 @@ def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False):
     return mat1
 
 
-def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False):
+def pad_mat2(
+    mat2: Tensor, *, k_padded_length: int, n_padded_length: int, is_bmm: bool = False
+) -> Tensor:
     if k_padded_length != 0 or n_padded_length != 0:
         # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding
         pad_arg = [0, n_padded_length, 0, k_padded_length]
@@ -834,7 +851,7 @@ def bmm_replace(mat1: Tensor, mat2: Tensor) -> Tensor:
 
 
 @functools.lru_cache(None)
-def _pad_mm_init():
+def _pad_mm_init() -> None:
     from .joint_graph import patterns
 
     if torch.cuda.is_available():
