Skip to content

Commit fac7468

Browse files
xmfanpytorchmergebot
authored andcommitted
[compiled autograd] fix node origin graph comments (pytorch#139003)
the comment update was done after prehooks were already collected, so prehooks would appear as part of the previous node Pull Request resolved: pytorch#139003 Approved by: https://github.com/yf225
1 parent f9ae3fa commit fac7468

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

torch/csrc/dynamo/python_compiled_autograd.cpp

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,19 @@ CacheNode* _compiled_autograd_impl(
647647

648648
for (size_t i = 0; i < calls.size(); i++) {
649649
NodeCall& call = *calls[i];
650+
651+
std::string _node_name = call.node->name();
652+
THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
653+
TORCH_INTERNAL_ASSERT(node_name != nullptr);
654+
THPObjectPtr set_node_origin(
655+
PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));
656+
PyObject* pyobj = Py_None;
657+
if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
658+
pyobj = pynode->obj;
659+
}
660+
check(PyObject_CallFunction(
661+
set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));
662+
650663
// TODO(jansel): consider adding some of this stuff:
651664
// guard(local_graph_task); NodeGuard ndguard(task.fn_); const auto
652665
// opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
@@ -692,20 +705,6 @@ CacheNode* _compiled_autograd_impl(
692705
inputs = THPVariable_UnpackList(pyinputs);
693706
}
694707

695-
std::string _node_name = call.node->name();
696-
THPObjectPtr node_name(PyUnicode_FromString(_node_name.data()));
697-
TORCH_INTERNAL_ASSERT(node_name != nullptr);
698-
THPObjectPtr set_node_origin(
699-
PyObject_GetAttrString(py_compiler.get(), "set_node_origin"));
700-
701-
PyObject* pyobj = Py_None;
702-
if (auto pynode = std::dynamic_pointer_cast<PyNode>(call.node)) {
703-
pyobj = pynode->obj;
704-
}
705-
706-
check(PyObject_CallFunction(
707-
set_node_origin, "OIO", node_name.get(), i, pyobj, nullptr));
708-
709708
SwapSavedVariables saved(compiler_call, state, py_compiler.get(), call);
710709
variable_list outputs = call.node->apply_with_saved(inputs, saved);
711710

0 commit comments

Comments
 (0)