Skip to content

Commit 0827464

Browse files
eellison authored and pytorchmergebot committed
Replace runtime type parameterization (pytorch#155221)
See the benchmark:

```
>>> import timeit; print(f"OrderedSet[str](): {timeit.timeit('OrderedSet[str]()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s, OrderedSet(): {timeit.timeit('OrderedSet()', setup='from torch.utils._ordered_set import OrderedSet', number=1000000):.6f}s")
```

> `OrderedSet[str]()`: 0.354622s, `OrderedSet()`: 0.095376s

Type parameterization should be in the type hint, not at runtime.

Pull Request resolved: pytorch#155221
Approved by: https://github.com/Skylion007, https://github.com/jansel
1 parent 7dcc77e commit 0827464

File tree

14 files changed

+55
-53
lines changed

14 files changed

+55
-53
lines changed

torch/_inductor/codegen/common.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1688,7 +1688,7 @@ def is_removed(self, name: str) -> bool:
16881688
# after you do a call into this kernel, which buffers actually contain
16891689
# updated data? Modeled off of python_argdefs.
16901690
def live_output_buffers(self) -> OrderedSet[str]:
1691-
live_outs = OrderedSet[str]()
1691+
live_outs: OrderedSet[str] = OrderedSet()
16921692
for inplaced in unique(self.inplace_buffers.values()):
16931693
if isinstance(inplaced, RemovedArg):
16941694
continue
@@ -1948,16 +1948,16 @@ def __init__(
19481948
self.num_reduction = 0
19491949

19501950
self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
1951-
self.must_keep_buffers = OrderedSet[str]()
1952-
self.store_buffer_names = OrderedSet[str]()
1951+
self.must_keep_buffers: OrderedSet[str] = OrderedSet()
1952+
self.store_buffer_names: OrderedSet[str] = OrderedSet()
19531953
self._load_mask: Optional[str] = None
19541954
self._load_other: Union[None, int, float] = None
19551955
# OrderedSet in set_current_node
19561956
self.current_node: Optional[SchedulerNode] = None
19571957
self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
19581958

1959-
self.removed_buffers = OrderedSet[str]()
1960-
self.inplaced_to_remove = OrderedSet[str]()
1959+
self.removed_buffers: OrderedSet[str] = OrderedSet()
1960+
self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
19611961

19621962
# key: the buffer to write
19631963
# value: the buffer to read and whose memory can be reused for
@@ -2144,7 +2144,7 @@ def remove_kernel_local_buffers(self) -> None:
21442144
for buf in self.store_buffer_names
21452145
if buf in scheduler.name_to_buf
21462146
)
2147-
names_to_remove = OrderedSet[str]()
2147+
names_to_remove: OrderedSet[str] = OrderedSet()
21482148
for name in self.store_buffer_names:
21492149
if (
21502150
name not in self.must_keep_buffers

torch/_inductor/codegen/cpp.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4897,7 +4897,7 @@ def get_call_ranges(node: BaseSchedulerNode):
48974897
# https://github.com/pytorch/pytorch/blob/1115a25c36340554442f28f9570abd42f0aface2/aten/src/ATen/native/cpu/SoftMaxKernel.cpp#L159 # noqa: B950
48984898
# where the buffer is with size of last dim and contiguous.
48994899
# Only support this typical case at first.
4900-
visited_scheduler_nodes = OrderedSet[str]()
4900+
visited_scheduler_nodes: OrderedSet[str] = OrderedSet()
49014901
for scheduler_node in node.get_nodes():
49024902
# all users inside same OuterLoopFusedSchedulerNode
49034903
assert isinstance(scheduler_node, SchedulerNode)

torch/_inductor/codegen/cpp_gemm_template.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1341,7 +1341,7 @@ def get_options(
13411341
reindexers: list[Optional[Callable[[list[Any]], list[Any]]]] = []
13421342
epilogue_creators: list[Callable[[ir.Buffer], ir.Pointwise]] = []
13431343
fake_buffers: list[ir.Buffer] = []
1344-
Y_aliases = OrderedSet[str]()
1344+
Y_aliases: OrderedSet[str] = OrderedSet()
13451345

13461346
use_local_acc = (
13471347
self.layout.dtype != torch.float

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ def __init__(self):
5656
if not hasattr(self, "device"):
5757
self.device = "cpu"
5858
# must be initialized prior to calling super().__init__()
59-
self.included_devices = OrderedSet[str]()
59+
self.included_devices: OrderedSet[str] = OrderedSet()
6060
super().__init__()
6161
self.declare = "auto "
6262
self.declare_maybe_reference = "decltype(auto) "
@@ -66,14 +66,14 @@ def __init__(self):
6666
self.supports_intermediate_hooks = False
6767
self.kernel_callsite_id = count()
6868
self.int_array_id = count() # for int array local variable declarations
69-
self.declared_int_array_vars = OrderedSet[str]()
69+
self.declared_int_array_vars: OrderedSet[str] = OrderedSet()
7070
self.tmp_tensor_id = count() # for tmp tensor local variable declarations
7171
self.arg_var_id = count()
72-
self.used_cached_devices = OrderedSet[str]()
73-
self.used_cached_dtypes = OrderedSet[str]()
74-
self.used_cached_layouts = OrderedSet[str]()
75-
self.used_cached_memory_formats = OrderedSet[str]()
76-
self.used_cond_predicate = OrderedSet[str]()
72+
self.used_cached_devices: OrderedSet[str] = OrderedSet()
73+
self.used_cached_dtypes: OrderedSet[str] = OrderedSet()
74+
self.used_cached_layouts: OrderedSet[str] = OrderedSet()
75+
self.used_cached_memory_formats: OrderedSet[str] = OrderedSet()
76+
self.used_cond_predicate: OrderedSet[str] = OrderedSet()
7777
self.cached_output_id = count()
7878
self.scalar_to_tensor_id = count()
7979
self.custom_op_wrapper_loaded = False

torch/_inductor/codegen/multi_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -224,7 +224,7 @@ def call_kernel(self, kernel_name):
224224

225225
def codegen_nan_check(self):
226226
wrapper = V.graph.wrapper_code
227-
seen = OrderedSet[str]()
227+
seen: OrderedSet[str] = OrderedSet()
228228
for k in self.kernels:
229229
_, call_args, precompile_args, _ = k.args.python_argdefs()
230230
for arg, precompile_arg in zip(call_args, precompile_args):

torch/_inductor/codegen/simd.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1259,8 +1259,8 @@ def generate_node_schedule(self, nodes, numel, rnumel):
12591259
done = OrderedSet[scheduler.BaseSchedulerNode]()
12601260
# Writes with a reduced shape, meaning they are only present once the
12611261
# reduction loop has ended
1262-
not_ready_yet_nodes = OrderedSet[str]()
1263-
current_loop_buffer_usage = OrderedSet[str]()
1262+
not_ready_yet_nodes: OrderedSet[str] = OrderedSet()
1263+
current_loop_buffer_usage: OrderedSet[str] = OrderedSet()
12641264
maybe_split_index: Optional[int] = None
12651265

12661266
def fits_in_main_body(n):
@@ -2327,7 +2327,7 @@ def get_tiling_and_scores(
23272327

23282328
return default_tiling, None
23292329

2330-
seen_names = OrderedSet[str]()
2330+
seen_names: OrderedSet[str] = OrderedSet()
23312331
candidate_tiles: Counter[CandidateTiling] = collections.Counter()
23322332
for node in EnableReduction.filter(node_schedule):
23332333
for candidate_tiling in cls.candidate_tilings(node, numel, reduction_numel):

torch/_inductor/codegen/simd_kernel_features.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -123,7 +123,7 @@ def contains_op(self, op_name: str) -> bool:
123123
return bool(self.op_counts().get(op_name))
124124

125125
def get_mutations(self) -> OrderedSet[str]:
126-
mutations = OrderedSet[str]()
126+
mutations: OrderedSet[str] = OrderedSet()
127127
for node in self.scheduler_nodes():
128128
for buf in node.get_outputs():
129129
mutations.update(buf.get_mutations())
@@ -132,7 +132,7 @@ def get_mutations(self) -> OrderedSet[str]:
132132
@cache_on_self
133133
def select_index_dtype(self) -> torch.dtype:
134134
# Gather all used buffer names
135-
buffer_names = OrderedSet[str]()
135+
buffer_names: OrderedSet[str] = OrderedSet()
136136
for node in self.scheduler_nodes():
137137
buffer_names.update(node.get_buffer_names())
138138
buffer_names.update(node.used_buffer_names())

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -749,7 +749,7 @@ class TritonCSEVariable(CSEVariable):
749749
def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None:
750750
super().__init__(name, bounds, dtype)
751751
# We'll use this to track which masks the variable needs when used for indirect indexing
752-
self.mask_vars = OrderedSet[str]()
752+
self.mask_vars: OrderedSet[str] = OrderedSet()
753753
assert dtype is not None, "TritonCSEVariable must have dtype"
754754

755755
def update_on_args(self, name, args, kwargs):
@@ -1769,7 +1769,7 @@ def indexing(
17691769
index_vars = index.free_symbols
17701770
has_rindex = False
17711771

1772-
mask_vars: OrderedSet[str] = OrderedSet[str]()
1772+
mask_vars: OrderedSet[str] = OrderedSet()
17731773
for var in sorted(index_vars, key=operator.attrgetter("name")):
17741774
assert isinstance(var, sympy.Symbol)
17751775
has_rindex = has_rindex or symbol_is_type(
@@ -1811,7 +1811,7 @@ def indexing(
18111811

18121812
have_dense = True
18131813
have_loop_vars = False
1814-
dense_mask_vars = OrderedSet[str]()
1814+
dense_mask_vars: OrderedSet[str] = OrderedSet()
18151815

18161816
for tree in self.active_range_trees():
18171817
if index_vars.intersection(tree.var_list):
@@ -3550,7 +3550,7 @@ def codegen_kernel(self, name=None):
35503550
arg.name, V.graph.sizevars.inv_precomputed_replacements[symbol]
35513551
)
35523552

3553-
mutated_args = OrderedSet[str]()
3553+
mutated_args: OrderedSet[str] = OrderedSet()
35543554
for mutation in self.mutations:
35553555
if mutation in self.args.input_buffers:
35563556
mutated_args.add(self.args.input_buffers[mutation])

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -536,7 +536,7 @@ def select_combo_heuristics(
536536
return heuristics_list[0], size_hints_list[0], self.sub_kernels[0]
537537

538538
def get_mutated_args_sub_kernels(self) -> list[str]:
539-
mutated_args = OrderedSet[str]()
539+
mutated_args: OrderedSet[str] = OrderedSet()
540540
for sub_kernel in self.sub_kernels:
541541
for mutation in sub_kernel.mutations:
542542
if mutation in sub_kernel.args.input_buffers:

torch/_inductor/codegen/wrapper.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -215,7 +215,7 @@ def writeline(line: str, example_grid: Optional[str] = None):
215215
else:
216216
assert len(grids) > 1
217217
assert len(grids) == len(configs)
218-
seen = OrderedSet[str]()
218+
seen: OrderedSet[str] = OrderedSet()
219219
# sort the configs from the largest # of kwargs to the smallest to
220220
# emit the grids in the order of (approximately) decreasing specificity
221221
# TODO(aakhundov): the sorting below is generally not sufficient, so
@@ -857,7 +857,7 @@ def __init__(self):
857857
self.kernel_autotune_defs = IndentedBuffer()
858858
self.kernel_autotune_calls = IndentedBuffer()
859859
self.subgraph_definitions = IndentedBuffer()
860-
self.kernel_autotune_names = OrderedSet[str]()
860+
self.kernel_autotune_names: OrderedSet[str] = OrderedSet()
861861
# Map key is the kernel argument name; value is a tuple of the resulting example
862862
# tensor name with the kernel where that tensor was most recently used.
863863
self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {}
@@ -877,7 +877,9 @@ def __init__(self):
877877
self.last_seen_device_guard_index: Optional[int] = None
878878
self.supports_intermediate_hooks = True
879879
self.user_defined_kernel_cache: dict[tuple[Any, ...], tuple[str, Any]] = {}
880-
self.unbacked_symbol_decls = OrderedSet[str]() # str of sympy.Symbol
880+
self.unbacked_symbol_decls: OrderedSet[str] = (
881+
OrderedSet()
882+
) # str of sympy.Symbol
881883
self.computed_sizes: OrderedSet[sympy.Symbol] = OrderedSet()
882884
self.launcher_fn_name = None
883885
# This function can be overridden to change the launcher name
@@ -921,9 +923,9 @@ def add_import_once(line: str) -> None:
921923

922924
self.add_import_once = add_import_once
923925
self._metas: dict[str, str] = {}
924-
self._meta_vars = OrderedSet[str]()
926+
self._meta_vars: OrderedSet[str] = OrderedSet()
925927
self.multi_kernel_state = MultiKernelState()
926-
self.already_codegened_subgraphs = OrderedSet[str]()
928+
self.already_codegened_subgraphs: OrderedSet[str] = OrderedSet()
927929
self.allocated_workspaces: dict[str, Any] = {}
928930

929931
# intermediate tensor value printing utility

0 commit comments

Comments (0)