@@ -58,6 +58,63 @@ int default_dyn_type_int = 0;
5858PyObject* python_verbose_logger = nullptr ;
5959} // namespace
6060
61+ // see https://github.com/pytorch/pytorch/pull/34845
62+ static void throw_python_error () {
63+ python_error err;
64+ err.persist ();
65+ throw std::move (err);
66+ }
67+
68+ // RuntimeState contains arbitrary callables created during the forward pass.
69+ // e.g. .retains_grad(). It is created during the compiled_args stage, and is
70+ // used at runtime. The lifetime of RuntimeState is a single backward pass.
71+ struct RuntimeState {
72+ at::TensorBase call_cpp_tensor_pre_hooks (
73+ size_t idx,
74+ const at::TensorBase& grad) {
75+ TORCH_INTERNAL_ASSERT (
76+ cpp_tensor_pre_hooks.size () > static_cast <size_t >(idx));
77+ return cpp_tensor_pre_hooks[idx](grad);
78+ }
79+
80+ std::vector<std::function<at::TensorBase(const at::TensorBase&)>>
81+ cpp_tensor_pre_hooks;
82+ size_t next_id = 0 ;
83+ };
84+
85+ static RuntimeState* active_rstate;
86+ struct RuntimeStateGuard {
87+ RuntimeStateGuard () : _state(std::make_unique<RuntimeState>()) {
88+ active_rstate = _state.get ();
89+ }
90+ RuntimeStateGuard (const RuntimeStateGuard&) = delete ;
91+ RuntimeStateGuard& operator =(const RuntimeStateGuard&) = delete ;
92+ RuntimeStateGuard (RuntimeStateGuard&&) = delete ;
93+ RuntimeStateGuard& operator =(RuntimeStateGuard&&) = delete ;
94+
95+ ~RuntimeStateGuard () {
96+ active_rstate = nullptr ;
97+ }
98+
99+ std::unique_ptr<RuntimeState> _state;
100+ };
101+
102+ static PyObject* call_cpp_tensor_pre_hooks (PyObject* dummy, PyObject* args) {
103+ HANDLE_TH_ERRORS;
104+ int idx = -1 ;
105+ PyObject* grad = nullptr ;
106+ if (!PyArg_ParseTuple (args, " iO" , &idx, &grad)) {
107+ throw_python_error ();
108+ }
109+ TORCH_INTERNAL_ASSERT (idx > -1 );
110+ TORCH_INTERNAL_ASSERT (grad != nullptr );
111+ TORCH_INTERNAL_ASSERT (active_rstate != nullptr );
112+ auto res = active_rstate->call_cpp_tensor_pre_hooks (
113+ static_cast <size_t >(idx), THPVariable_Unpack (grad));
114+ return THPVariable_Wrap (res);
115+ END_HANDLE_TH_ERRORS;
116+ }
117+
61118// List[Optional[Tensor]] in Python can't be directly parsed into a
62119// List[Tensor], so we need to do this conversion manually.
63120static std::vector<at::Tensor> toTensorList (
@@ -253,13 +310,6 @@ static PyObject* convert_pyobj_list(std::vector<c10::SafePyObject>& inputs) {
253310 return pyinput;
254311}
255312
256- // see https://github.com/pytorch/pytorch/pull/34845
257- static void throw_python_error () {
258- python_error err;
259- err.persist ();
260- throw std::move (err);
261- }
262-
263313static PyObject* check (PyObject* pyresult) {
264314 if (C10_UNLIKELY (pyresult == nullptr )) {
265315 throw_python_error ();
@@ -608,6 +658,10 @@ static PyMethodDef _methods[] = {
608658 {" clear_cache" , clear_cache, METH_NOARGS, nullptr },
609659 {" is_cache_empty" , is_cache_empty, METH_NOARGS, nullptr },
610660 {" set_verbose_logger" , set_verbose_logger, METH_VARARGS, nullptr },
661+ {" call_cpp_tensor_pre_hooks" ,
662+ call_cpp_tensor_pre_hooks,
663+ METH_VARARGS,
664+ nullptr },
611665 {nullptr , nullptr , 0 , nullptr }};
612666
613667static struct PyModuleDef _module = {
@@ -827,7 +881,8 @@ static CacheNode* _compiled_autograd_impl(
827881 THPObjectPtr* graph_arg_sizes,
828882 THPObjectPtr* graph_arg_ivalue_args,
829883 THPObjectPtr* graph_arg_hooks,
830- THPObjectPtr* graph_arg_packed_inputs) {
884+ THPObjectPtr* graph_arg_packed_inputs,
885+ RuntimeState* rstate) {
831886 const std::unordered_map<Node*, int >& dependencies = graph_task.dependencies_ ;
832887 std::unordered_map<Node*, int > visited_dependencies;
833888 visited_dependencies.reserve (dependencies.size ());
@@ -963,6 +1018,20 @@ static CacheNode* _compiled_autograd_impl(
9631018 }
9641019 inputs = THPVariable_UnpackList (pyinputs);
9651020 }
1021+ if (!call.cpp_tensor_pre_hooks .empty ()) {
1022+ // proxy a call to runtimestate
1023+ THPObjectPtr pyinputs (THPVariable_WrapList (inputs));
1024+ for (const auto & [hook_id, idx] : call.cpp_tensor_pre_hooks ) {
1025+ pyinputs = check (PyObject_CallMethod (
1026+ py_compiler,
1027+ " cpp_tensor_pre_hook" ,
1028+ " Oii" ,
1029+ pyinputs.get (),
1030+ hook_id,
1031+ idx));
1032+ }
1033+ inputs = THPVariable_UnpackList (pyinputs);
1034+ }
9661035 for (const auto & graph_output : call.graph_output ) {
9671036 int input_nr = graph_output.first ;
9681037 int output_index = graph_output.second ;
@@ -1090,6 +1159,7 @@ static CacheNode* _compiled_autograd_impl(
10901159 wrap_lifted_ivalue_args (compiler_call.lifted_ivalue_args .args );
10911160 *graph_arg_hooks = convert_pyobj_list (compiler_call.hooks );
10921161 *graph_arg_packed_inputs = convert_pyobj_list (compiler_call.packed_inputs );
1162+ rstate->cpp_tensor_pre_hooks = std::move (compiler_call.cpp_tensor_pre_hooks );
10931163 return cache;
10941164}
10951165
@@ -1125,6 +1195,7 @@ static variable_list compiled_autograd(
11251195 LockGuardWithErrorLogs lock_guard (mtx);
11261196 pybind11::gil_scoped_acquire gil;
11271197 at::ThreadLocalStateGuard tls_guard (graph_task.thread_locals_ );
1198+ RuntimeStateGuard rstate_guard;
11281199
11291200 THPObjectPtr inputs;
11301201 THPObjectPtr sizes;
@@ -1140,7 +1211,8 @@ static variable_list compiled_autograd(
11401211 &sizes,
11411212 &ivalue_args,
11421213 &hooks,
1143- &packed_inputs);
1214+ &packed_inputs,
1215+ active_rstate);
11441216
11451217 THPObjectPtr pyresult (check (PyObject_CallFunctionObjArgs (
11461218 cache->runtime_wrapper .get (),
0 commit comments