Skip to content

Commit 4516c59

Browse files
anijain2305pytorchmergebot
authored andcommitted
[dynamo][source] Add special source for __code__ and __closure__ (pytorch#159722)
Pull Request resolved: pytorch#159722 Approved by: https://github.com/jansel
1 parent 8bc843a commit 4516c59

File tree

4 files changed

+261
-5
lines changed

4 files changed

+261
-5
lines changed

torch/_dynamo/guards.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@
105105
CallFunctionNoArgsSource,
106106
CallMethodItemSource,
107107
ChainedSource,
108+
ClosureSource,
109+
CodeSource,
108110
ConstantSource,
109111
ConstDictKeySource,
110112
DataclassFieldsSource,
@@ -1495,6 +1497,20 @@ def get_guard_manager_from_source(self, source):
14951497
example_value=example_value,
14961498
guard_manager_enum=guard_manager_enum,
14971499
)
1500+
elif istype(source, CodeSource):
1501+
assert base_guard_manager # to make mypy happy
1502+
out = base_guard_manager.code_manager(
1503+
source=source_name,
1504+
example_value=example_value,
1505+
guard_manager_enum=guard_manager_enum,
1506+
)
1507+
elif istype(source, ClosureSource):
1508+
assert base_guard_manager # to make mypy happy
1509+
out = base_guard_manager.closure_manager(
1510+
source=source_name,
1511+
example_value=example_value,
1512+
guard_manager_enum=guard_manager_enum,
1513+
)
14981514
else:
14991515
raise AssertionError(
15001516
f"missing guard manager builder {source} - {source.name()}"
@@ -1568,7 +1584,10 @@ def arg_ref(self, guard: Union[str, Guard]) -> str:
15681584
return name
15691585

15701586
def _guard_on_attribute(self, guard: Guard, attr_name: str, guard_fn):
1571-
attr_source = AttrSource(guard.originating_source, attr_name)
1587+
if attr_name == "__code__":
1588+
attr_source = CodeSource(guard.originating_source)
1589+
else:
1590+
attr_source = AttrSource(guard.originating_source, attr_name) # type: ignore[assignment]
15721591
# Copy the stack info
15731592
new_guard = Guard(
15741593
attr_source, guard_fn, stack=guard.stack, user_stack=guard.user_stack
@@ -1580,6 +1599,9 @@ def HASATTR(self, guard: Guard):
15801599
source = guard.originating_source
15811600
if isinstance(source, NNModuleSource):
15821601
source = source.base
1602+
if isinstance(source, CodeSource):
1603+
# No need to guard that a function has a __code__ attribute
1604+
return
15831605
assert isinstance(source, AttrSource), f"invalid source {guard.name}"
15841606
base_source = source.base
15851607
base = base_source.name()

torch/_dynamo/source.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,34 @@ def reconstruct(self, codegen: "PyCodegen") -> None:
285285
# local cell object should never be used for guards.
286286

287287

288+
# Represents obj.__code__ where object is type object
289+
@dataclasses.dataclass(frozen=True)
290+
class CodeSource(ChainedSource):
291+
def reconstruct(self, codegen: "PyCodegen") -> None:
292+
codegen(self.base)
293+
codegen.extend_output(codegen.create_load_attrs("__code__"))
294+
295+
def guard_source(self) -> GuardSource:
296+
return self.base.guard_source()
297+
298+
def name(self) -> str:
299+
return f"{self.base.name()}.__code__"
300+
301+
302+
# Represents obj.__closure__ where object is type object
303+
@dataclasses.dataclass(frozen=True)
304+
class ClosureSource(ChainedSource):
305+
def reconstruct(self, codegen: "PyCodegen") -> None:
306+
codegen(self.base)
307+
codegen.extend_output(codegen.create_load_attrs("__closure__"))
308+
309+
def guard_source(self) -> GuardSource:
310+
return self.base.guard_source()
311+
312+
def name(self) -> str:
313+
return f"{self.base.name()}.__closure__"
314+
315+
288316
# Represents tensor.grad source. It could be represented by AttrSource as well.
289317
# But, we could access grad field on tensor directly in C++ without going
290318
# through the Python bytecodes. Therefore, we use a separate source for grad

torch/_dynamo/variables/functions.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,13 @@
5656
Unsupported,
5757
)
5858
from ..guards import GuardBuilder, install_guard
59-
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
59+
from ..source import (
60+
AttrSource,
61+
ClosureSource,
62+
ConstantSource,
63+
DefaultsSource,
64+
GetItemSource,
65+
)
6066
from ..utils import (
6167
check_constant_args,
6268
check_unspec_or_constant_args,
@@ -436,9 +442,7 @@ def bind_args(self, parent, args, kwargs) -> dict[str, VariableTracker]:
436442
cell_var = side_effects[cell]
437443

438444
elif self.source:
439-
closure_cell = GetItemSource(
440-
AttrSource(self.source, "__closure__"), idx
441-
)
445+
closure_cell = GetItemSource(ClosureSource(self.source), idx)
442446
closure_cell_contents = AttrSource(closure_cell, "cell_contents")
443447
try:
444448
contents_var = VariableTracker.build(

torch/csrc/dynamo/guards.cpp

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5786,6 +5786,158 @@ class WeakRefCallGuardAccessor : public GuardAccessor {
57865786
void clone_visitor(WeakRefCallGuardAccessor* to) {}
57875787
};
57885788

5789+
/**
5790+
* Represent x.__code__
5791+
*/
5792+
class CodeGuardAccessor : public GuardAccessor {
5793+
public:
5794+
// name = __type_mro_accessor__, a unique string used as attribute name.
5795+
CodeGuardAccessor(
5796+
RootGuardManager* root,
5797+
py::str name,
5798+
std::string source,
5799+
py::handle example_value,
5800+
py::handle guard_manager_enum)
5801+
: GuardAccessor(
5802+
root,
5803+
std::move(name),
5804+
std::move(source),
5805+
example_value,
5806+
guard_manager_enum) {}
5807+
5808+
// NB: Intentional duplication between check_nopybind and
5809+
// check_verbose_nopybind.
5810+
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
5811+
override { // borrowed ref
5812+
PyObject* func = obj;
5813+
if (PyMethod_Check(obj)) {
5814+
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
5815+
} else if (PyInstanceMethod_Check(obj)) {
5816+
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
5817+
}
5818+
PyObject* x = PyFunction_GetCode(func); // borrowed ref
5819+
if (x == nullptr) {
5820+
PyErr_Clear();
5821+
return false;
5822+
}
5823+
return _guard_manager->check_nopybind(x);
5824+
}
5825+
5826+
GuardDebugInfo check_verbose_nopybind(
5827+
PyObject* obj) override { // borrowed ref
5828+
PyObject* func = obj;
5829+
if (PyMethod_Check(obj)) {
5830+
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
5831+
} else if (PyInstanceMethod_Check(obj)) {
5832+
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
5833+
}
5834+
PyObject* x = PyFunction_GetCode(func);
5835+
if (x == nullptr) {
5836+
PyErr_Clear();
5837+
return GuardDebugInfo(
5838+
false,
5839+
std::string(repr() + ": Not a function on ") + get_source(),
5840+
0);
5841+
}
5842+
5843+
return _guard_manager->check_verbose_nopybind(x);
5844+
}
5845+
5846+
std::string repr() const override {
5847+
return "CodeGuardAccessor";
5848+
}
5849+
5850+
public: // cloning functions
5851+
CodeGuardAccessor(GuardManager* guard_manager, CodeGuardAccessor* from)
5852+
: GuardAccessor(guard_manager, from) {
5853+
from->clone_visitor(this);
5854+
}
5855+
5856+
GuardAccessor* clone(
5857+
RootGuardManager* cloned_root,
5858+
const py::function& clone_filter_fn) override {
5859+
return clone_common<CodeGuardAccessor>(cloned_root, clone_filter_fn);
5860+
}
5861+
5862+
void clone_visitor(CodeGuardAccessor* to) {}
5863+
};
5864+
5865+
/**
5866+
* Represent x.__closure__
5867+
*/
5868+
class ClosureGuardAccessor : public GuardAccessor {
5869+
public:
5870+
// name = __type_mro_accessor__, a unique string used as attribute name.
5871+
ClosureGuardAccessor(
5872+
RootGuardManager* root,
5873+
py::str name,
5874+
std::string source,
5875+
py::handle example_value,
5876+
py::handle guard_manager_enum)
5877+
: GuardAccessor(
5878+
root,
5879+
std::move(name),
5880+
std::move(source),
5881+
example_value,
5882+
guard_manager_enum) {}
5883+
5884+
// NB: Intentional duplication between check_nopybind and
5885+
// check_verbose_nopybind.
5886+
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
5887+
override { // borrowed ref
5888+
PyObject* func = obj;
5889+
if (PyMethod_Check(obj)) {
5890+
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
5891+
} else if (PyInstanceMethod_Check(obj)) {
5892+
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
5893+
}
5894+
PyObject* x = PyFunction_GetClosure(func); // borrowed ref
5895+
if (x == nullptr) {
5896+
PyErr_Clear();
5897+
return false;
5898+
}
5899+
return _guard_manager->check_nopybind(x);
5900+
}
5901+
5902+
GuardDebugInfo check_verbose_nopybind(
5903+
PyObject* obj) override { // borrowed ref
5904+
PyObject* func = obj;
5905+
if (PyMethod_Check(obj)) {
5906+
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
5907+
} else if (PyInstanceMethod_Check(obj)) {
5908+
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
5909+
}
5910+
PyObject* x = PyFunction_GetClosure(func);
5911+
if (x == nullptr) {
5912+
PyErr_Clear();
5913+
return GuardDebugInfo(
5914+
false,
5915+
std::string(repr() + ": Not a function on ") + get_source(),
5916+
0);
5917+
}
5918+
5919+
return _guard_manager->check_verbose_nopybind(x);
5920+
}
5921+
5922+
std::string repr() const override {
5923+
return "ClosureGuardAccessor";
5924+
}
5925+
5926+
public: // cloning functions
5927+
ClosureGuardAccessor(GuardManager* guard_manager, ClosureGuardAccessor* from)
5928+
: GuardAccessor(guard_manager, from) {
5929+
from->clone_visitor(this);
5930+
}
5931+
5932+
GuardAccessor* clone(
5933+
RootGuardManager* cloned_root,
5934+
const py::function& clone_filter_fn) override {
5935+
return clone_common<ClosureGuardAccessor>(cloned_root, clone_filter_fn);
5936+
}
5937+
5938+
void clone_visitor(ClosureGuardAccessor* to) {}
5939+
};
5940+
57895941
/**
57905942
* Implements function call no args - e.g, torch.cuda.current_device()
57915943
*/
@@ -6451,6 +6603,16 @@ PyObject* torch_c_dynamo_guards_init() {
64516603
std::unique_ptr<TupleIteratorGetItemAccessor>>(
64526604
py_m, "TupleIteratorGetItemAccessor");
64536605
// NOLINTNEXTLINE(bugprone-unused-raii)
6606+
py::class_<
6607+
CodeGuardAccessor,
6608+
GuardAccessor,
6609+
std::unique_ptr<CodeGuardAccessor>>(py_m, "CodeGuardAccessor");
6610+
// NOLINTNEXTLINE(bugprone-unused-raii)
6611+
py::class_<
6612+
ClosureGuardAccessor,
6613+
GuardAccessor,
6614+
std::unique_ptr<ClosureGuardAccessor>>(py_m, "ClosureGuardAccessor");
6615+
// NOLINTNEXTLINE(bugprone-unused-raii)
64546616
py::class_<
64556617
GlobalWeakRefGuardAccessor,
64566618
GuardAccessor,
@@ -6971,6 +7133,46 @@ PyObject* torch_c_dynamo_guards_init() {
69717133
py::return_value_policy::reference)
69727134
// return by reference because GuardManager has the ownership of accessors
69737135
// and guard managers
7136+
.def(
7137+
"code_manager",
7138+
[](GuardManager& self,
7139+
std::string source,
7140+
py::handle example_value,
7141+
py::handle guard_manager_enum) -> GuardManager* {
7142+
// A unique key is used to save as the accessor key.
7143+
py::str unique_key("__code_accessor__");
7144+
return self.get_child_manager<CodeGuardAccessor>(
7145+
std::move(unique_key),
7146+
std::move(source),
7147+
example_value,
7148+
guard_manager_enum);
7149+
},
7150+
py::arg("source"),
7151+
py::arg("example_value"),
7152+
py::arg("guard_manager_enum"),
7153+
py::return_value_policy::reference)
7154+
// return by reference because GuardManager has the ownership of accessors
7155+
// and guard managers
7156+
.def(
7157+
"closure_manager",
7158+
[](GuardManager& self,
7159+
std::string source,
7160+
py::handle example_value,
7161+
py::handle guard_manager_enum) -> GuardManager* {
7162+
// A unique key is used to save as the accessor key.
7163+
py::str unique_key("__closure_accessor__");
7164+
return self.get_child_manager<ClosureGuardAccessor>(
7165+
std::move(unique_key),
7166+
std::move(source),
7167+
example_value,
7168+
guard_manager_enum);
7169+
},
7170+
py::arg("source"),
7171+
py::arg("example_value"),
7172+
py::arg("guard_manager_enum"),
7173+
py::return_value_policy::reference)
7174+
// return by reference because GuardManager has the ownership of accessors
7175+
// and guard managers
69747176
.def(
69757177
"global_weakref_manager",
69767178
&GuardManager::get_child_manager<GlobalWeakRefGuardAccessor>,

0 commit comments

Comments
 (0)