Skip to content

Commit 2ac4140

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][dicts] Guarding lazily on dict keys (pytorch#143997)
Pull Request resolved: pytorch#143997 Approved by: https://github.com/jansel
1 parent e05d677 commit 2ac4140

File tree

8 files changed

+215
-139
lines changed

8 files changed

+215
-139
lines changed

benchmarks/dynamo/pr_time_benchmarks/expected_results.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,27530000000,0.015
1818

1919

2020

21-
basic_modules_ListOfLinears_eager,compile_time_instruction_count,945667911,0.015
21+
basic_modules_ListOfLinears_eager,compile_time_instruction_count,928600000,0.015
2222

2323

2424

test/dynamo/test_dicts.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,31 @@ def fn(x):
6262
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
6363
self.assertEqual(fn(x), opt_fn(x))
6464

65+
def test_dict_contains(self):
66+
sd = dict()
67+
sd[2] = 5
68+
sd[4] = 10
69+
70+
def fn(x):
71+
if 1 in sd:
72+
x = x * 2
73+
else:
74+
x = x * 3
75+
return x
76+
77+
x = torch.randn(4)
78+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
79+
self.assertEqual(fn(x), opt_fn(x))
80+
81+
# Ensure a recompilation
82+
sd[1] = 15
83+
self.assertEqual(fn(x), opt_fn(x))
84+
85+
# Ensure not recompilation because the traced program remains same here.
86+
sd[2] = 10
87+
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
88+
self.assertEqual(fn(x), opt_fn(x))
89+
6590
def test_dict_subclass_methods_fallback_readonly(self):
6691
sd = SimpleDict()
6792
sd[2] = 5
@@ -318,6 +343,55 @@ def fn(x, d):
318343
x = torch.randn(4)
319344
self.assertEqual(opt_fn(x, d), fn(x, d))
320345

346+
def test_lazy_key_guarding(self):
347+
d = {"a": 2, "b": 3, "c": 5}
348+
349+
def fn(x):
350+
return x * d["a"]
351+
352+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
353+
354+
x = torch.randn(4)
355+
ref = fn(x)
356+
res = opt_fn(x)
357+
self.assertEqual(ref, res)
358+
359+
# Since key c was not used, it should not lead to a recompilation
360+
d.pop("c")
361+
d["d"] = 10
362+
363+
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
364+
ref = fn(x)
365+
res = opt_fn(x)
366+
self.assertEqual(ref, res)
367+
368+
def test_lazy_key_non_const_guarding(self):
369+
d = {
370+
list: 2,
371+
dict: 3,
372+
OrderedDict: 5,
373+
namedtuple: 7,
374+
}
375+
376+
def fn(x):
377+
return x * d[list]
378+
379+
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
380+
381+
x = torch.randn(4)
382+
ref = fn(x)
383+
res = opt_fn(x)
384+
self.assertEqual(ref, res)
385+
386+
# Since key c was not used, it should not lead to a recompilation
387+
d.pop(dict)
388+
d[defaultdict] = 10
389+
390+
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
391+
ref = fn(x)
392+
res = opt_fn(x)
393+
self.assertEqual(ref, res)
394+
321395
def test_dict_mutation_side_effect(self):
322396
def fn(d):
323397
d["c"] = d["a"] + d.pop("b")

test/dynamo/test_misc.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,6 @@ def fn(x, y):
11231123
L['x'].requires_grad == False
11241124
L['x'].size()[1] == L['x'].size()[0]
11251125
L['x'].storage_offset() == 0
1126-
___dict_contains('builtins', G['sys'].modules)
11271126
___dict_contains('operator', G['sys'].modules)
11281127
___dict_contains('operator', G['sys'].modules)
11291128
hasattr(L['x'], '_dynamo_dynamic_indices') == False

torch/_dynamo/guards.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@
123123
builtin_dict_keys,
124124
common_constant_types,
125125
dict_keys,
126-
dict_keys_repr,
127126
get_custom_getattr,
128127
get_torch_function_mode_stack,
129128
get_torch_function_mode_stack_at,
@@ -422,7 +421,7 @@ def _get_closure_vars():
422421
"___odict_getitem": collections.OrderedDict.__getitem__,
423422
"___key_to_id": key_to_id,
424423
"___dict_version": dict_version,
425-
"___dict_contains": lambda a, b: a in b,
424+
"___dict_contains": lambda a, b: dict.__contains__(b, a),
426425
"___tuple_iterator_len": tuple_iterator_len,
427426
"___normalize_range_iter": normalize_range_iter,
428427
"___tuple_iterator_getitem": tuple_iterator_getitem,
@@ -1732,29 +1731,6 @@ def DUPLICATE_INPUT(self, guard, source_b):
17321731
get_verbose_code_parts(code, guard),
17331732
)
17341733

1735-
def DICT_KEYS(self, guard):
1736-
# Guard on the keys and their order
1737-
ref = self.arg_ref(guard)
1738-
value = self.get(guard.name)
1739-
1740-
self.TYPE_MATCH(guard)
1741-
code = []
1742-
any_key_is_id = any(key_is_id(k) for k in builtin_dict_keys(value))
1743-
const_keys_repr = dict_keys_repr(
1744-
key_to_id(value),
1745-
local=is_from_local_source(guard.originating_source),
1746-
)
1747-
if any_key_is_id:
1748-
code.append(f"___key_to_id({ref}) == {const_keys_repr}")
1749-
else:
1750-
code.append(f"list({ref}.keys()) == {const_keys_repr}")
1751-
1752-
self._set_guard_export_info(guard, code)
1753-
if self.requires_key_order_guarding(guard.originating_source):
1754-
self.guard_on_dict_keys_and_order(value, guard)
1755-
else:
1756-
self.guard_on_dict_keys_and_ignore_order(value, guard)
1757-
17581734
def WEAKREF_ALIVE(self, guard):
17591735
code = [f"{self.arg_ref(guard)} is not None"]
17601736

@@ -1763,11 +1739,18 @@ def WEAKREF_ALIVE(self, guard):
17631739
get_verbose_code_parts(code, guard)
17641740
)
17651741

1766-
def DICT_CONST_KEYS(self, guard):
1767-
"""Constant keys match"""
1742+
def DICT_KEYS_MATCH(self, guard):
1743+
"""Insert guard to check that the keys of a dict are same"""
17681744
ref = self.arg_ref(guard)
17691745
value = self.get(guard.name)
17701746

1747+
if value is torch.utils._pytree.SUPPORTED_NODES:
1748+
# For SUPPORTED_NODES, we can guard on the dictionary version (PEP509).
1749+
self.DICT_VERSION(guard)
1750+
return
1751+
1752+
self.SEQUENCE_LENGTH(guard)
1753+
17711754
code = []
17721755
# Ensure that we call dict.keys and not value.keys (which can call
17731756
# overridden keys method). In the C++ guards, we relied on PyDict_Next

torch/_dynamo/variables/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DefaultDictVariable,
2727
DictKeySetVariable,
2828
FrozensetVariable,
29+
NNModuleHooksDictVariable,
2930
SetVariable,
3031
)
3132
from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable

torch/_dynamo/variables/builder.py

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
import operator
1515
import random
1616
import re
17-
import sys
1817
import types
1918
import warnings
2019
import weakref
@@ -142,7 +141,6 @@
142141
DefaultDictVariable,
143142
DictKeySetVariable,
144143
FrozensetVariable,
145-
PythonSysModulesVariable,
146144
SetVariable,
147145
)
148146
from .distributed import (
@@ -574,37 +572,15 @@ def create_2d_tma_descriptor():
574572
output, tuple_cls=type(value), source=self.source
575573
)
576574
return result
577-
elif value is torch.utils._pytree.SUPPORTED_NODES:
578-
# For SUPPORTED_NODES, we guard on the dictionary version (PEP509)
579-
# under the assumption that the values themselves don't change.
580-
self.install_guards(GuardBuilder.DICT_VERSION)
581-
582-
# The keys on the SUPPORTED_NODES can be arbitrary, so save on the
583-
# key order.
584-
self.tx.output.guard_on_key_order.add(self.source.name())
585-
result = {
586-
TypingVariable(k): UserDefinedObjectVariable(
587-
v,
588-
source=DictGetItemSource(
589-
self.get_source(), ConstDictKeySource(self.get_source(), i)
590-
),
591-
)
592-
for i, (k, v) in enumerate(value.items())
593-
}
594-
return ConstDictVariable(result, type(value))
595-
elif value is sys.modules:
596-
self.install_guards(GuardBuilder.FUNCTION_MATCH)
597-
return PythonSysModulesVariable(source=self.source)
598575
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
599-
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
600-
601-
# Optimisation for the common case strings, ints, etc
576+
self.install_guards(GuardBuilder.TYPE_MATCH)
602577
all_const = all(ConstantVariable.is_literal(k) for k in value.keys())
603-
if all_const:
604-
# TODO(anijain2305) - Do we have to guard on all the keys? Can
605-
# keys be guarded lazily, similar to values?
606-
self.install_guards(GuardBuilder.DICT_CONST_KEYS)
607-
else:
578+
579+
# For all_const, we dont have to guard on anything yet. We guard on
580+
# keys lazily by adding a dict_getitem entry for each accessed key.
581+
# For cases where we need to guard on all keys, we lazily put guards
582+
# during the dict call_method (check dicts.py)
583+
if not all_const:
608584
# Guard on the key order
609585
# This is not ideal, i.e., there is no need to guard on the key
610586
# order. But we guard on the key order because of the complexity
@@ -725,7 +701,7 @@ def build_key_value(i, k, v):
725701

726702
install_guard(
727703
self.get_source().make_guard(GuardBuilder.TYPE_MATCH),
728-
keywords_source.make_guard(GuardBuilder.DICT_KEYS),
704+
keywords_source.make_guard(GuardBuilder.DICT_KEYS_MATCH),
729705
args_source.make_guard(GuardBuilder.SEQUENCE_LENGTH),
730706
)
731707
return FunctoolsPartialVariable(func_obj, args, keywords)

0 commit comments

Comments
 (0)