
Commit a3ab27b

bobrenjc93 authored and pytorchmergebot committed
Migrate from Tuple -> tuple in torch/_inductor (pytorch#144264)
Pull Request resolved: pytorch#144264
Approved by: https://github.com/eellison
1 parent 778d953 commit a3ab27b


72 files changed: +428 −515 lines
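
For readers skimming the diff, the whole commit is one mechanical pattern: PEP 585 (accepted for Python 3.9) made the builtin tuple subscriptable in annotations, so typing.Tuple is redundant on the Python versions PyTorch supports. A minimal sketch of the before/after pattern — the split_path_* names below are hypothetical, for illustration only, not from this PR:

# Before: tuple generics must be imported from typing
from typing import Tuple

def split_path_legacy(path: str) -> Tuple[str, str]:
    head, _, tail = path.rpartition("/")
    return head, tail

# After: PEP 585 builtin generic; no typing import needed (Python 3.9+)
def split_path_modern(path: str) -> tuple[str, str]:
    head, _, tail = path.rpartition("/")
    return head, tail

Note that Dict, List, and Optional imports from typing are left in place throughout these diffs; the commit scopes itself to Tuple only.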

torch/_inductor/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -130,7 +130,7 @@ def aoti_compile_and_package(
 def _aoti_compile_and_package_inner(
     gm: torch.nn.Module,
     # flat_example_inputs: List[Any],
-    args: Tuple[Any],
+    args: tuple[Any],
     kwargs: Optional[Dict[str, Any]] = None,
     *,
     load_and_run: bool = False,
@@ -198,7 +198,7 @@ def aoti_load_package(path: Union[str, io.BytesIO]) -> Any: # type: ignore[type

 def aot_compile(
     gm: torch.fx.GraphModule,
-    args: Tuple[Any],
+    args: tuple[Any],
     kwargs: Optional[Dict[str, Any]] = None,
     *,
     options: Optional[Dict[str, Any]] = None,

torch/_inductor/aoti_eager.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional
 from unittest import mock

 import torch
@@ -85,7 +85,7 @@ def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
     return {int: torch.int32, float: torch.float, bool: torch.bool}


-def supported_scalar_types() -> Tuple[type, ...]:
+def supported_scalar_types() -> tuple[type, ...]:
     type_to_torch_dtype = supported_builtin_dtype_torch_dtype()
     return tuple(type_to_torch_dtype.keys())

@@ -170,7 +170,7 @@ def aoti_compile_with_persistent_cache(
     device_type: str,
     dynamic: bool,
     f: Callable[..., Any],
-    args: Tuple[Any],
+    args: tuple[Any],
     kwargs: Dict[str, Any],
     *,
     dynamic_shapes: Optional[Dict[str, Any]] = None,

torch/_inductor/autoheuristic/artifacts/_MMRankingA100.py

Lines changed: 1 addition & 1 deletion
@@ -238,7 +238,7 @@ def fill_choices(self) -> None:
     def get_name(self) -> str:
         return 'mm'

-    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
         if context.get_value('arith_intensity') <= 52.6245059967041:
             if context.get_value('n') <= 34.0:
                 if context.get_value('n') <= 18.0:

torch/_inductor/autoheuristic/artifacts/_MMRankingH100.py

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ def fill_choices(self) -> None:
     def get_name(self) -> str:
         return 'mm'

-    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
         if context.get_value('arith_intensity') <= 29.89772129058838:
             if context.get_value('n') <= 34.0:
                 if context.get_value('n') <= 18.0:

torch/_inductor/autoheuristic/artifacts/_MixedMMA100.py

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ def fill_choices(self) -> None:
     def get_name(self) -> str:
         return 'mixed_mm'

-    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
         if str(context.get_value('1LEQmLEQ16')) != 'True':
             if context.get_value('m') <= 32.5:
                 if context.get_value('n') <= 6976.0:

torch/_inductor/autoheuristic/artifacts/_MixedMMH100.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def fill_choices(self) -> None:
     def get_name(self) -> str:
         return 'mixed_mm'

-    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
         if context.get_value('arith_intensity') <= 15.988086223602295:
             if context.get_value('n') <= 25280.0:
                 if context.get_value('n') <= 1344.0:

torch/_inductor/autoheuristic/autoheuristic_utils.py

Lines changed: 4 additions & 4 deletions
@@ -1,5 +1,5 @@
 import functools
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Any, Callable, Dict, List

 import torch

@@ -64,7 +64,7 @@ def add_feature(
         self.features.append(AHFeature(name, value, is_categorical=is_categorical))
         self.context_dict[name] = value

-    def get_numerical_and_categorical_features(self) -> Tuple[List[str], List[str]]:
+    def get_numerical_and_categorical_features(self) -> tuple[List[str], List[str]]:
         numerical_features = []
         categorical_features = []
         for feature in self.features:
@@ -93,7 +93,7 @@ class AHMetadata:
     def __init__(
         self,
         shared_memory: Any,
-        device_capa: Tuple[int, int],
+        device_capa: tuple[int, int],
         choices: List[Choice],
         name: str,
     ) -> None:
@@ -327,7 +327,7 @@ def mat2_is_contig_fn(data: Any) -> bool:
     return [mat1_is_contig_op, mat2_is_contig_op]


-def context_add_strides(context: AHContext, name: str, stride: Tuple[int, ...]) -> None:
+def context_add_strides(context: AHContext, name: str, stride: tuple[int, ...]) -> None:
     for i, s in enumerate(stride):
         context.add_feature(f"{name}_stride_{i}", s)

torch/_inductor/autoheuristic/learnedheuristic_interface.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional

 from torch._inductor.autoheuristic.autoheuristic_utils import (
     AHContext,
@@ -88,5 +88,5 @@ def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
         choices = [choice for choice in choices if choice is not None]
         return choices

-    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
         return []

torch/_inductor/codecache.py

Lines changed: 21 additions & 22 deletions
@@ -42,7 +42,6 @@
     NoReturn,
     Optional,
     Sequence,
-    Tuple,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -375,7 +374,7 @@ def code_hash(code: Union[str, bytes], extra: str = "") -> str:

 def get_path(
     basename: str, extension: str, specified_dir: str = ""
-) -> Tuple[str, str, str]:
+) -> tuple[str, str, str]:
     if specified_dir:
         if os.path.isabs(specified_dir):
             subdir = specified_dir
@@ -403,7 +402,7 @@ def write(
     extra: str = "",
     hash_type: str = "code",
     specified_dir: str = "",
-) -> Tuple[str, str]:
+) -> tuple[str, str]:
     # use striped content to compute hash so we don't end up with different
     # hashes just because the content begins/ends with different number of
     # spaces.
@@ -527,7 +526,7 @@ def __init__(

     def _reduce_fake_tensor(
         self, t: Tensor
-    ) -> Tuple[Callable[[T], T], Tuple[TensorMetadata]]:
+    ) -> tuple[Callable[[T], T], tuple[TensorMetadata]]:
         """
         Custom reducer to pickle FakeTensors.
         """
@@ -537,7 +536,7 @@ def _reduce_fake_tensor(
     def _reduce_tensor(
         self,
         t: Tensor,
-    ) -> Tuple[Callable[[T], T], Tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
+    ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
         """
         Custom reducer to pickle Tensors. If we see tensors, we know they're constants
         stored as attributes on the GraphModule.
@@ -570,7 +569,7 @@ def _reduce_tensor(
         # Otherwise, we just include the metadata.
         return (_ident, (metadata,))

-    def _reduce_symint(self, s: SymInt) -> Tuple[Callable[[T], T], Tuple[str]]:
+    def _reduce_symint(self, s: SymInt) -> tuple[Callable[[T], T], tuple[str]]:
         """
         Custom reducer to pickle SymInts.
         """
@@ -588,7 +587,7 @@ def _reduce_unsupported(self, s: Any) -> NoReturn:

     def _reduce_graph_module(
         self, gm: torch.fx.GraphModule
-    ) -> Tuple[Any, Tuple[Dict[str, Any], str]]:
+    ) -> tuple[Any, tuple[Dict[str, Any], str]]:
         """
         Custom reducer for graph module to handle irrelevant data for user
         defined triton kernels
@@ -860,7 +859,7 @@ def compiled_fx_graph_hash(
     example_inputs: Sequence[InputType],
     fx_kwargs: _CompileFxKwargs,
     inputs_to_check: Sequence[int],
-) -> Tuple[str, List[str]]:
+) -> tuple[str, List[str]]:
     """
     Generate a unique hash of the FX graph for caching.
     """
@@ -980,7 +979,7 @@ def _lookup_graph(
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
         constants: CompiledFxGraphConstants,
-    ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
+    ) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
         """
         Lookup a compiled graph in the cache by key. On a hit, return the
         deserialized CompiledFxGraph object. On a miss, return None.
@@ -1222,7 +1221,7 @@ def prepare_key(
         fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
         remote: bool,
-    ) -> Tuple[Optional[Tuple[str, List[str]]], Dict[str, Any]]:
+    ) -> tuple[Optional[tuple[str, List[str]]], Dict[str, Any]]:
         """
         Checks that the inductor input is cacheable, then computes
         and returns the cache key for the input.
@@ -1274,7 +1273,7 @@ def load_with_key(
         remote_cache: Optional[RemoteCache[JsonDataTy]],
         is_backward: bool,
         constants: CompiledFxGraphConstants,
-    ) -> Tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
+    ) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
         """
         Lookup the graph with the given key, and return results and metadata.
         Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1347,7 +1346,7 @@ def run_command_and_check(cmd_: str) -> None:


 @functools.lru_cache(None)
-def split_aot_inductor_output_path(path: str) -> Tuple[str, str]:
+def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
     """Returns the path where the AOT Inductor compiled kernels are stored."""
     if path.endswith(".so"):
         return os.path.split(path)
@@ -2720,18 +2719,18 @@ class PyCodeCache:
     # than once, but attach different attributes, i.e., due to different
     # constant values.
     modules: List[ModuleType] = []
-    linemaps: Dict[str, List[Tuple[Any, ...]]] = {}
+    linemaps: Dict[str, List[tuple[Any, ...]]] = {}

     @classmethod
-    def write(cls, source_code: str, extra: str = "") -> Tuple[str, str]:
+    def write(cls, source_code: str, extra: str = "") -> tuple[str, str]:
         return write(source_code, "py", extra=extra)

     @classmethod
     def load(
         cls,
         source_code: str,
         extra: str = "",
-        linemap: Optional[List[Tuple[int, str]]] = None,
+        linemap: Optional[List[tuple[int, str]]] = None,
         attrs: Optional[Dict[str, Any]] = None,
     ) -> ModuleType:
         key, path = write(source_code, "py", extra=extra)
@@ -2742,7 +2741,7 @@ def load_by_key_path(
         cls,
         key: str,
         path: str,
-        linemap: Optional[List[Tuple[int, str]]] = None,
+        linemap: Optional[List[tuple[int, str]]] = None,
         attrs: Optional[Dict[str, Any]] = None,
     ) -> ModuleType:
         if linemap is None:
@@ -3043,7 +3042,7 @@ class CacheEntry:
     _SOURCE_CODE_SUFFIX = "cu"

     @classmethod
-    def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]:
+    def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]:
         """
         Writes source code into a file with dst_file_ext as the file extension.
         Returns the hash key of source code, and the path to the file.
@@ -3060,7 +3059,7 @@ def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]:
     @classmethod
     def compile(
         cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
-    ) -> Tuple[str, str, str]:
+    ) -> tuple[str, str, str]:
         """
         Compiles CUDA source_code into a file with dst_file_ext extension.
         Returns a tuple of dst_file_path, hash_key, source_code_path
@@ -3099,7 +3098,7 @@ def compile(
         return (cls.cache[key].output_path, key, input_path)

     @classmethod
-    def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]:
+    def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]:
         """
         Compiles source code and loads the generated .so file.
         Returns a tuple of DLLWrapper, hash_key, source_code_path
@@ -3129,7 +3128,7 @@ class CacheEntry:
     _logged_compiler_version = False

     @classmethod
-    def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]:
+    def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]:
         """
         Writes source code into a file with dst_file_ext as the file extension.
         Returns the hash key of source code, and the path to the file.
@@ -3146,7 +3145,7 @@ def write(cls, source_code: str, dst_file_ext: str) -> Tuple[str, str]:
     @classmethod
     def compile(
         cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
-    ) -> Tuple[str, str, str]:
+    ) -> tuple[str, str, str]:
         """
         Compiles source_code into a file with dst_file_ext extension,
         using the compile command specific for the ROCm platform.
@@ -3194,7 +3193,7 @@ def compile(
         return (cls.cache[key].output_path, key, input_path)

     @classmethod
-    def load(cls, source_code: str, dst_file_ext: str) -> Tuple[DLLWrapper, str, str]:
+    def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]:
         """
         Compiles source code and loads the generated .so file.
         Returns a tuple of DLLWrapper, hash_key, source_code_path

torch/_inductor/codegen/block_analysis.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 import collections
 import functools
 import textwrap
-from typing import List, Optional, Tuple
+from typing import List, Optional

 import sympy
 from sympy import Expr, Symbol
@@ -49,7 +49,7 @@ def match_mod_div_block_expr(
         index_var: Symbol,
         numel: Expr,
         num_dims: int,
-    ) -> Optional[Tuple[List[Expr], List[Expr], List[Expr]]]:
+    ) -> Optional[tuple[List[Expr], List[Expr], List[Expr]]]:
         """
         Matches modular indexing expressions, converting them to implied block dimensions and strides.
         See triton.py for more information.
