
Commit 7482523

xmfan authored and pytorchmergebot committed
[ca] introduce RuntimeState to support c++ hooks via graph breaks (pytorch#149987)
Pull Request resolved: pytorch#149987
Approved by: https://github.com/jansel
ghstack dependencies: pytorch#149647, pytorch#149709, pytorch#149651, pytorch#149897
1 parent dcb378c commit 7482523

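For context, the retains_grad hooks that this change unblocks are the C++ tensor pre-hooks installed by Tensor.retain_grad(). Below is a minimal sketch of the kind of program the newly enabled tests exercise, assuming the torch._dynamo.compiled_autograd.enable context manager used throughout test_compiled_autograd.py and an arbitrary compiler_fn; before this change, collecting such a hook hit the "retains_grad_hooks not implemented for compiled autograd" check removed further down in this diff.

import torch
import torch._dynamo.compiled_autograd as compiled_autograd

def compiler_fn(gm):
    # Any compiled-autograd compiler works here; torch.compile is the common choice.
    return torch.compile(gm)

x = torch.randn(4, requires_grad=True)
y = x * 2
y.retain_grad()  # installs a C++ retains_grad (tensor pre-) hook on y's grad_fn

with compiled_autograd.enable(compiler_fn):
    y.sum().backward()

print(y.grad)  # populated even though y is not a leaf
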
File tree: 7 files changed, +143 -23 lines


test/inductor/test_compiled_autograd.py

Lines changed: 11 additions & 11 deletions

@@ -4146,6 +4146,17 @@ def wrap_test_class(orig_cls):
     "test_checkpointing_without_reentrant_memory_savings",  # reentrant .backward
     "test_dtensor_basic",  # torch._dynamo.exc.Unsupported: Failed to convert args/kwargs to proxy
     "test_dtensor_contiguous_dtensor_noncontiguous_local_as_tangent",  # subclass constructor
+    "test_retain_grad",  # retains_grad_hooks
+    "test_retain_grad_cycle",  # retains_grad_hooks
+    "test_retain_grad_inplace",  # retains_grad_hooks
+    "test_retain_grad_inplace_over_view",  # retains_grad_hooks
+    "test_retains_grad_can_always_observe_tensor_prehook",  # retains_grad_hooks
+    "test_retains_grad_inplace_multiple_outputs",  # retains_grad_hooks
+    "test_hook_edge_case_when_called_with_grad",  # retains_grad_hooks
+    "test_multi_grad_all_hooks",  # retains_grad_hooks
+    "test_prehook_ordering",  # retains_grad_hooks
+    "test_will_engine_execute_node",  # retains_grad_hooks
+    "test_backward_to_node",  # retains_grad_hooks
 }
 
 test_contexts = {
@@ -4173,11 +4184,6 @@ def wrap_test_class(orig_cls):
     "test_autograd_inplace_views_cross_dtype",  # view_fn not supported by compiled autograd
     "test_current_node",  # TorchDispatchMode not yet implemented for compiled autograd
     "test_post_accumulate_grad_hook_ordering",  # accuracy error
-    "test_retain_grad_cycle",  # retains_grad_hooks
-    "test_retain_grad_inplace",  # retains_grad_hooks
-    "test_retain_grad_inplace_over_view",  # retains_grad_hooks
-    "test_retains_grad_can_always_observe_tensor_prehook",  # retains_grad_hooks
-    "test_retains_grad_inplace_multiple_outputs",  # retains_grad_hooks
     "test_accumulate_grad",  # create_graph
     "test_anomaly_assign_parent_cleanup",  # create_graph
     "test_backward_create_graph_warns",  # create_graph
@@ -4198,19 +4204,13 @@ def wrap_test_class(orig_cls):
     "test_grad_nonleaf",  # create_graph
     "test_grad_nonleaf_many_outputs",  # create_graph
     "test_hessian_vector",  # create_graph
-    "test_hook_edge_case_when_called_with_grad",  # retains_grad_hooks
     "test_inplace_on_view_backward",  # create_graph
     "test_multi_grad_any_hooks",  # register_multi_grad_hook
-    "test_multi_grad_all_hooks",  # retains_grad_hooks
     "test_nested_anomaly_detect_nan",  # create_graph
     "test_nested_anomaly_printstack_cleanup",  # create_graph
     "test_once_differentiable",  # create_graph
-    "test_prehook_ordering",  # retains_grad_hooks
-    "test_retain_grad",  # retains_grad_hooks
     "test_saved_variable_packing_unpacking_saved_original_with_hooks",  # create_graph
     "test_select_sum",  # create_graph, also needs graph breaks
-    "test_will_engine_execute_node",  # retains_grad_hooks
-    "test_backward_to_node",  # retains_grad_hooks NYI
     "test_custom_autograd_no_early_free",  # create_graph
     "test_custom_function_error",  # vjp
     "test_custom_function_save_for_forward",  # vjp

torch/_C/_dynamo/compiled_autograd.pyi

Lines changed: 2 additions & 0 deletions

@@ -1,5 +1,6 @@
 from typing import Callable
 
+from torch import Tensor
 from torch._dynamo.compiled_autograd import AutogradCompilerInstance
 
 def set_autograd_compiler(
@@ -9,3 +10,4 @@ def set_autograd_compiler(
 def clear_cache() -> None: ...
 def is_cache_empty() -> bool: ...
 def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...
+def call_cpp_tensor_pre_hooks(idx: int, grad: Tensor) -> Tensor: ...

torch/_dynamo/compiled_autograd.py

Lines changed: 12 additions & 0 deletions

@@ -734,6 +734,18 @@ def tensor_pre_hook(self, inputs, hook_id, i: int):
         self.bind_objects_to_proxies([inputs[i]], [proxy])
         return inputs
 
+    def cpp_tensor_pre_hook(self, inputs: list[torch.Tensor], hook_id: int, i: int):
+        proxy = self.fx_tracer.create_proxy(
+            "call_function",
+            torch._C._dynamo.compiled_autograd.call_cpp_tensor_pre_hooks,
+            (hook_id, self.to_proxy(inputs[i])),
+            {},
+        )
+        with disable_proxy_modes_tracing():
+            inputs[i] = maybe_clone(inputs[i])
+        self.bind_objects_to_proxies([inputs[i]], [proxy])
+        return inputs
+
     def pre_hook(self, inputs, hook_id):
         assert self.hooks_proxy is not None
         hook = self.hooks_proxy[hook_id]  # type: ignore[index]

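cpp_tensor_pre_hook mirrors the existing tensor_pre_hook codepath directly above it: both trace a call that takes the incoming gradient for one node input and returns a (possibly replaced) gradient. That is the same contract exposed to Python users through Tensor.register_hook, shown here as a small self-contained refresher with no compiled autograd involved:

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()

def double_grad(grad):
    return grad * 2  # replace the incoming gradient

def observe_only(grad):
    return None  # returning None keeps the gradient unchanged

h1 = x.register_hook(double_grad)
h2 = x.register_hook(observe_only)
y.backward()

print(x.grad)  # tensor([4., 4., 4.]): base gradient of 2, doubled once
h1.remove()
h2.remove()
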
torch/csrc/autograd/cpp_hook.cpp

Lines changed: 5 additions & 0 deletions

@@ -64,4 +64,9 @@ variable_list CppFunctionSingleTensorPreHook::operator()(
   return results;
 }
 
+void CppFunctionSingleTensorPreHook::compiled_args(
+    torch::dynamo::autograd::CompiledNodeArgs& args) const {
+  args.add_cpp_single_tensor_pre_hook(hook_, value_idx_);
+}
+
 } // namespace torch::autograd

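This override is the hook's side of the collection pass: during compiled_args, each C++ pre-hook registers itself with the CompiledNodeArgs collector, which stores the callable and remembers which input index it applies to. A schematic Python rendering of that interaction, with invented class names purely for illustration:

class Collector:
    # Stands in for CompiledNodeArgs: assigns hook ids and records (id, index) pairs.
    def __init__(self):
        self.cpp_tensor_pre_hooks = []   # callables, indexed by hook id
        self.node_hook_refs = []         # (hook_id, input_index) for the current node

    def add_cpp_single_tensor_pre_hook(self, fn, idx):
        self.cpp_tensor_pre_hooks.append(fn)
        self.node_hook_refs.append((len(self.cpp_tensor_pre_hooks) - 1, idx))

class SingleTensorPreHook:
    # Stands in for CppFunctionSingleTensorPreHook.
    def __init__(self, fn, value_idx):
        self.fn = fn
        self.value_idx = value_idx

    def compiled_args(self, collector):
        # The hook reports itself; the collector decides how to store it.
        collector.add_cpp_single_tensor_pre_hook(self.fn, self.value_idx)

collector = Collector()
hook = SingleTensorPreHook(lambda g: g, value_idx=0)
hook.compiled_args(collector)
print(collector.node_hook_refs)  # [(0, 0)]
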
torch/csrc/autograd/cpp_hook.h

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,9 @@ struct CppFunctionSingleTensorPreHook : public FunctionPreHook {
       size_t value_idx);
   variable_list operator()(const variable_list& values) override;
 
+  void compiled_args(
+      torch::dynamo::autograd::CompiledNodeArgs& args) const override;
+
   std::function<at::TensorBase(const at::TensorBase&)> hook_;
   size_t value_idx_;
 };

torch/csrc/dynamo/compiled_autograd.h

Lines changed: 29 additions & 3 deletions

@@ -164,6 +164,7 @@ struct NodeCall {
   uint32_t id;
   std::shared_ptr<Node> node;
   std::vector<std::pair<int, int>> tensor_pre_hooks;
+  std::vector<std::pair<int, int>> cpp_tensor_pre_hooks;
   std::vector<int> pre_hooks;
   std::vector<int> post_hooks;
   std::vector<int> post_acc_grad_hooks;
@@ -333,6 +334,12 @@ struct AutogradCompilerCall {
     return hooks.size() - 1;
   }
 
+  size_t emplace_cpp_tensor_pre_hook(
+      std::function<at::TensorBase(const at::TensorBase&)>&& fn) {
+    cpp_tensor_pre_hooks.emplace_back(std::move(fn));
+    return cpp_tensor_pre_hooks.size() - 1;
+  }
+
   size_t emplace_packed_input(c10::SafePyObject&& input) {
     packed_inputs.emplace_back(std::move(input));
     return packed_inputs.size() - 1;
@@ -348,6 +355,8 @@ struct AutogradCompilerCall {
   LiftedIValueArgs lifted_ivalue_args;
   std::vector<int64_t> dyn_size_inputs;
   std::vector<c10::SafePyObject> hooks;
+  std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
+      cpp_tensor_pre_hooks;
   std::vector<c10::SafePyObject> packed_inputs;
   NodeCalls node_calls;
   SizeInput::DynType default_dyn_type;
@@ -602,12 +611,12 @@ class CompiledNodeArgs {
 #undef COLLECT_AS_BYTES
 
   void collect_hooks_from(Node* fn) {
-    TORCH_CHECK(
-        fn->retains_grad_hooks().empty(),
-        "retains_grad_hooks not implemented for compiled autograd");
     for (auto& i : fn->tensor_pre_hooks()) {
       i->compiled_args(*this);
     }
+    for (auto& [_, i] : fn->retains_grad_hooks()) {
+      i->compiled_args(*this);
+    }
     for (auto& i : fn->pre_hooks()) {
       i->compiled_args(*this);
     }
@@ -647,6 +656,23 @@ class CompiledNodeArgs {
     _node_call.tensor_pre_hooks.emplace_back(fn_id, index);
   }
 
+  void add_cpp_single_tensor_pre_hook(
+      const std::function<at::TensorBase(const at::TensorBase&)>& hook,
+      size_t idx) {
+    auto wrapper = [hook](const at::TensorBase& grad) {
+      // handle when hook returns nothing
+      auto out = hook(grad);
+      if (!out.defined()) {
+        return grad;
+      }
+      return out;
+    };
+
+    auto hook_id = _compiler.emplace_cpp_tensor_pre_hook(std::move(wrapper));
+    collect_size(hook_id);
+    _node_call.cpp_tensor_pre_hooks.emplace_back(hook_id, idx);
+  }
+
   void add_pre_hook(c10::SafePyObject&& obj) {
     auto fn_id = _compiler.emplace_hook(std::move(obj));
     collect_size(fn_id);

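The wrapper lambda in add_cpp_single_tensor_pre_hook above guards against hooks that return an undefined tensor by falling back to the incoming gradient. A rough Python analogue of that defensive wrapping, included only to make the intent concrete (the helper name is invented):

import torch
from typing import Callable, Optional

def wrap_pre_hook(hook: Callable[[torch.Tensor], Optional[torch.Tensor]]):
    # If the wrapped hook returns None (the analogue of an undefined
    # at::TensorBase), keep the original gradient instead of dropping it.
    def wrapper(grad: torch.Tensor) -> torch.Tensor:
        out = hook(grad)
        return grad if out is None else out
    return wrapper

observed = []
safe_hook = wrap_pre_hook(lambda g: observed.append(g.clone()))  # hook returns None
print(safe_hook(torch.ones(2)))  # tensor([1., 1.]): gradient passes through unchanged
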
torch/csrc/dynamo/python_compiled_autograd.cpp

Lines changed: 81 additions & 9 deletions

@@ -58,6 +58,63 @@ int default_dyn_type_int = 0;
 PyObject* python_verbose_logger = nullptr;
 } // namespace
 
+// see https://github.com/pytorch/pytorch/pull/34845
+static void throw_python_error() {
+  python_error err;
+  err.persist();
+  throw std::move(err);
+}
+
+// RuntimeState contains arbitrary callables created during the forward pass.
+// e.g. .retains_grad(). It is created during the compiled_args stage, and is
+// used at runtime. The lifetime of RuntimeState is a single backward pass.
+struct RuntimeState {
+  at::TensorBase call_cpp_tensor_pre_hooks(
+      size_t idx,
+      const at::TensorBase& grad) {
+    TORCH_INTERNAL_ASSERT(
+        cpp_tensor_pre_hooks.size() > static_cast<size_t>(idx));
+    return cpp_tensor_pre_hooks[idx](grad);
+  }
+
+  std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
+      cpp_tensor_pre_hooks;
+  size_t next_id = 0;
+};
+
+static RuntimeState* active_rstate;
+struct RuntimeStateGuard {
+  RuntimeStateGuard() : _state(std::make_unique<RuntimeState>()) {
+    active_rstate = _state.get();
+  }
+  RuntimeStateGuard(const RuntimeStateGuard&) = delete;
+  RuntimeStateGuard& operator=(const RuntimeStateGuard&) = delete;
+  RuntimeStateGuard(RuntimeStateGuard&&) = delete;
+  RuntimeStateGuard& operator=(RuntimeStateGuard&&) = delete;
+
+  ~RuntimeStateGuard() {
+    active_rstate = nullptr;
+  }
+
+  std::unique_ptr<RuntimeState> _state;
+};
+
+static PyObject* call_cpp_tensor_pre_hooks(PyObject* dummy, PyObject* args) {
+  HANDLE_TH_ERRORS;
+  int idx = -1;
+  PyObject* grad = nullptr;
+  if (!PyArg_ParseTuple(args, "iO", &idx, &grad)) {
+    throw_python_error();
+  }
+  TORCH_INTERNAL_ASSERT(idx > -1);
+  TORCH_INTERNAL_ASSERT(grad != nullptr);
+  TORCH_INTERNAL_ASSERT(active_rstate != nullptr);
+  auto res = active_rstate->call_cpp_tensor_pre_hooks(
+      static_cast<size_t>(idx), THPVariable_Unpack(grad));
+  return THPVariable_Wrap(res);
+  END_HANDLE_TH_ERRORS;
+}
+
 // List[Optional[Tensor]] in Python can't be directly parsed into a
 // List[Tensor], so we need to do this conversion manually.
 static std::vector<at::Tensor> toTensorList(
@@ -253,13 +310,6 @@ static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
   return pyinput;
 }
 
-// see https://github.com/pytorch/pytorch/pull/34845
-static void throw_python_error() {
-  python_error err;
-  err.persist();
-  throw std::move(err);
-}
-
 static PyObject* check(PyObject* pyresult) {
   if (C10_UNLIKELY(pyresult == nullptr)) {
     throw_python_error();
@@ -608,6 +658,10 @@ static PyMethodDef _methods[] = {
     {"clear_cache", clear_cache, METH_NOARGS, nullptr},
     {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
     {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
+    {"call_cpp_tensor_pre_hooks",
+     call_cpp_tensor_pre_hooks,
+     METH_VARARGS,
+     nullptr},
     {nullptr, nullptr, 0, nullptr}};
 
 static struct PyModuleDef _module = {
@@ -827,7 +881,8 @@ static CacheNode* _compiled_autograd_impl(
     THPObjectPtr* graph_arg_sizes,
     THPObjectPtr* graph_arg_ivalue_args,
     THPObjectPtr* graph_arg_hooks,
-    THPObjectPtr* graph_arg_packed_inputs) {
+    THPObjectPtr* graph_arg_packed_inputs,
+    RuntimeState* rstate) {
   const std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
   std::unordered_map<Node*, int> visited_dependencies;
   visited_dependencies.reserve(dependencies.size());
@@ -963,6 +1018,20 @@
        }
        inputs = THPVariable_UnpackList(pyinputs);
      }
+      if (!call.cpp_tensor_pre_hooks.empty()) {
+        // proxy a call to runtimestate
+        THPObjectPtr pyinputs(THPVariable_WrapList(inputs));
+        for (const auto& [hook_id, idx] : call.cpp_tensor_pre_hooks) {
+          pyinputs = check(PyObject_CallMethod(
+              py_compiler,
+              "cpp_tensor_pre_hook",
+              "Oii",
+              pyinputs.get(),
+              hook_id,
+              idx));
+        }
+        inputs = THPVariable_UnpackList(pyinputs);
+      }
      for (const auto& graph_output : call.graph_output) {
        int input_nr = graph_output.first;
        int output_index = graph_output.second;
@@ -1090,6 +1159,7 @@
   wrap_lifted_ivalue_args(compiler_call.lifted_ivalue_args.args);
   *graph_arg_hooks = convert_pyobj_list(compiler_call.hooks);
   *graph_arg_packed_inputs = convert_pyobj_list(compiler_call.packed_inputs);
+  rstate->cpp_tensor_pre_hooks = std::move(compiler_call.cpp_tensor_pre_hooks);
   return cache;
 }
 
@@ -1125,6 +1195,7 @@ static variable_list compiled_autograd(
   LockGuardWithErrorLogs lock_guard(mtx);
   pybind11::gil_scoped_acquire gil;
   at::ThreadLocalStateGuard tls_guard(graph_task.thread_locals_);
+  RuntimeStateGuard rstate_guard;
 
   THPObjectPtr inputs;
   THPObjectPtr sizes;
@@ -1140,7 +1211,8 @@ static variable_list compiled_autograd(
       &sizes,
       &ivalue_args,
      &hooks,
-      &packed_inputs);
+      &packed_inputs,
+      active_rstate);
 
   THPObjectPtr pyresult(check(PyObject_CallFunctionObjArgs(
       cache->runtime_wrapper.get(),

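Taken together, RuntimeState and RuntimeStateGuard form an RAII pair: a fresh state holding the collected C++ hooks is published through a process-global pointer for the duration of one compiled backward pass, and the generated graph reaches it by hook id through the call_cpp_tensor_pre_hooks trampoline. A rough Python analogue of that lifetime management, purely illustrative (the names mirror the C++ but nothing here is a real PyTorch API):

import contextlib
from typing import Callable, List, Optional

class RuntimeState:
    # Holds the callables gathered while collecting one backward graph.
    def __init__(self) -> None:
        self.cpp_tensor_pre_hooks: List[Callable] = []

_active_rstate: Optional[RuntimeState] = None

@contextlib.contextmanager
def runtime_state_guard():
    # Publish a fresh state for exactly one backward pass, then clear it,
    # mirroring RuntimeStateGuard's constructor and destructor.
    global _active_rstate
    _active_rstate = RuntimeState()
    try:
        yield _active_rstate
    finally:
        _active_rstate = None

def call_cpp_tensor_pre_hooks(idx, grad):
    # The generated graph calls back into the active state by hook id.
    assert _active_rstate is not None, "only valid while a backward pass is running"
    return _active_rstate.cpp_tensor_pre_hooks[idx](grad)

with runtime_state_guard() as state:
    state.cpp_tensor_pre_hooks.append(lambda g: g * 2)
    print(call_cpp_tensor_pre_hooks(0, 3.0))  # 6.0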